Compare commits

53 Commits

Author SHA1 Message Date
3e54662d7f 更新1.16.0版本 2025-09-15 15:52:30 +08:00
7ca1008461 优化镜像构建 2025-09-15 07:46:36 +00:00
a52899ca02 添加判断截图方法 2025-09-15 07:46:19 +00:00
0f0e666e67 关闭清晰度测试 2025-09-15 07:45:59 +00:00
bf1000a848 如果没有结算单号就填住院号 2025-09-15 07:45:25 +00:00
2656976efa 修正ie表更新的修改人 2025-08-20 02:57:31 +00:00
cae997fcf7 删除特殊调整的案子查询 2025-08-20 09:42:35 +08:00
e018250344 更新版本1.15.5 2025-08-19 19:10:12 +08:00
c0f5ca2eb4 临时调整识别案子范围,尝试修正模型替换导致的识别率下降 2025-08-19 19:07:29 +08:00
ffed64b0b9 临时调整识别案子范围,尝试修正模型替换导致的识别率下降 2025-08-19 18:34:32 +08:00
1625f0294f 从主分支删除开发环境相关内容 2025-08-19 10:11:12 +08:00
f19f8cbcae 复制项目内容 2025-08-18 16:38:53 +08:00
ba3e23d185 修正grep警告 2025-08-18 16:01:01 +08:00
0abf7abb5b 修正语言包警告 2025-08-18 15:15:29 +08:00
47ac6aadbe 修正语言包警告 2025-08-18 15:10:39 +08:00
09ede1af25 删除doc_dewarp扭曲矫正 2025-08-18 14:21:10 +08:00
e40d963bf5 调整开发环境镜像构建 2025-08-18 14:05:38 +08:00
34344edd29 调整开发环境镜像构建 2025-08-18 13:58:43 +08:00
b387db1e08 添加远程开发环境容器 2025-08-18 13:29:23 +08:00
88ca27928f 删除过期数据 2025-08-13 10:50:15 +08:00
109a5e9444 删除过期数据 2025-08-12 11:13:14 +08:00
ab5f78cc7b 测试图片清晰度模型效果 2025-07-30 15:11:01 +08:00
04358ee646 增加图片损坏的判断 2025-06-26 08:58:57 +08:00
a67c53f470 增加图片损坏的判断 2025-06-26 08:48:11 +08:00
cd604bc1eb 修正图片可能因为某些原因获取不到而无法继续的问题 2025-05-21 11:26:01 +08:00
0de9fc14b5 修正高清图片为空的判断 2025-04-01 15:24:32 +08:00
5287df4959 添加图片清晰度测试,保存结果对照 2025-04-01 15:11:40 +08:00
3e9c0c99b9 Revert "测试图片清晰度"
This reverts commit a740f16e6b.
2025-04-01 14:47:56 +08:00
a740f16e6b 测试图片清晰度 2025-04-01 14:27:54 +08:00
b9606771cf 判断图片清晰度 2025-04-01 14:24:19 +08:00
110bc57abc Revert "测试清晰度模型运行情况"
This reverts commit f965bc4289.
2025-03-20 16:15:33 +08:00
f965bc4289 测试清晰度模型运行情况 2025-03-20 15:26:54 +08:00
8b6bf03d76 二维码识别替换前备份 2025-03-20 10:23:04 +08:00
73536aea89 删除一个识别服务 2025-02-17 16:08:30 +08:00
12f6554d8c 修正全角数字存入数据库失败的问题 2025-02-17 13:00:52 +08:00
be94bc7f09 删除一个识别服务 2025-02-17 10:41:54 +08:00
8307fcd549 添加两个识别服务 2025-02-17 10:09:11 +08:00
ab2dbf7c15 删除两个涂抹服务 2025-02-08 10:59:18 +08:00
dac7a1b5ce 添加两个涂抹服务 2025-02-08 09:55:17 +08:00
935fa26067 修正无效方法调用 2025-01-22 16:25:04 +08:00
16ab4c78d5 调整各字段长度限制,与其他表长度保持一致 2025-01-22 16:19:52 +08:00
ed77b8ed82 调整医生姓名长度限制 2025-01-22 15:44:14 +08:00
f4ec3b1eb4 paddlelite2.10可成功转换 2025-01-15 16:47:21 +08:00
a93f68413c 批量下载图片 2025-01-15 16:44:32 +08:00
0a947274dc paddle模型转nb端侧部署模型 2025-01-08 16:04:57 +08:00
f7540c4574 修正计算误差导致分割产生无效图片 2025-01-02 13:05:10 +08:00
5e6a471954 增加ocr结果存表 2024-12-24 14:55:43 +08:00
96b8a06e6c 优化医生名字的处理,没有时填充无 2024-12-03 13:27:18 +08:00
be27f753ba 优化数据字段的处理,增加限制条件 2024-12-03 13:12:58 +08:00
8ea5420520 捕获到内存错误主动抛出中止程序,以释放显存 2024-11-28 10:27:44 +08:00
9749577c5a 优化docker服务,删除无效的det_api 2024-11-05 17:30:16 +08:00
e0f6b82dad 更新版本号 2024-11-05 17:21:37 +08:00
8223759fdf 调整自动识别的文字识别模型 2024-11-05 17:18:12 +08:00
37 changed files with 403 additions and 2271 deletions

View File

@@ -15,6 +15,7 @@ ENV PYTHONUNBUFFERED=1 \
COPY requirements.txt /app/requirements.txt
COPY packages /app/packages
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo '$TZ' > /etc/timezone \
&& python3 -m pip install --upgrade pip \
&& pip install --no-cache-dir -r requirements.txt \
&& pip uninstall -y onnxruntime onnxruntime-gpu \
&& pip install onnxruntime-gpu==1.18.0 --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/

View File

@@ -125,4 +125,9 @@ bash update.sh
1. 新增文档检测功能
2. 新增扭曲矫正功能
21. 版本号1.14.0
1. 新增二维码识别替换高清图片功能
1. 新增二维码识别替换高清图片功能
22. 版本号1.15.0
1. 新增图片清晰度测试
23. 版本号1.16.0
1. 优化结算单号规则
2. 新增判断截图方法

View File

@@ -19,5 +19,5 @@ DB_URL = f'mysql+pymysql://{USERNAME}:{PASSWORD}@{HOSTNAME}:{PORT}/{DATABASE}'
SHOW_SQL = False
Engine = create_engine(DB_URL, echo=SHOW_SQL)
Base = declarative_base(Engine)
Base = declarative_base()
MysqlSession = sessionmaker(bind=Engine)

View File

@@ -1,5 +1,5 @@
# coding: utf-8
from sqlalchemy import Column, DECIMAL, Date, DateTime, Index, String, text, LargeBinary
from sqlalchemy import Column, DECIMAL, Date, DateTime, Index, String, text, LargeBinary, Text
from sqlalchemy.dialects.mysql import BIT, CHAR, INTEGER, TINYINT, VARCHAR
from db import Base
@@ -56,13 +56,16 @@ class ZxIeCost(Base):
pk_ie_cost = Column(INTEGER(11), primary_key=True, comment='费用明细信息抽取主键')
pk_phhd = Column(INTEGER(11), nullable=False, unique=True, comment='报销案子主键')
name = Column(String(30), comment='患者姓名')
content = Column(Text, comment='详细内容')
name = Column(String(20), comment='患者姓名')
admission_date_str = Column(String(255), comment='入院日期字符串')
admission_date = Column(Date, comment='入院日期')
discharge_date_str = Column(String(255), comment='出院日期字符串')
discharge_date = Column(Date, comment='出院日期')
medical_expenses_str = Column(String(255), comment='费用总额字符串')
medical_expenses = Column(DECIMAL(18, 2), comment='费用总额')
page_nums = Column(String(255), comment='页码')
page_count = Column(TINYINT(4), comment='页数')
create_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
creator = Column(String(255), comment='创建人')
update_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
@@ -94,19 +97,19 @@ class ZxIeDischarge(Base):
pk_ie_discharge = Column(INTEGER(11), primary_key=True, comment='出院记录信息抽取主键')
pk_phhd = Column(INTEGER(11), nullable=False, unique=True, comment='报销案子主键')
content = Column(String(5000), comment='详细内容')
hospital = Column(String(255), comment='医院')
content = Column(Text, comment='详细内容')
hospital = Column(String(200), comment='医院')
pk_yljg = Column(INTEGER(11), comment='医院主键')
department = Column(String(255), comment='科室')
department = Column(String(200), comment='科室')
pk_ylks = Column(INTEGER(11), comment='科室主键')
name = Column(String(30), comment='患者姓名')
name = Column(String(20), comment='患者姓名')
age = Column(INTEGER(3), comment='年龄')
admission_date_str = Column(String(255), comment='入院日期字符串')
admission_date = Column(Date, comment='入院日期')
discharge_date_str = Column(String(255), comment='出院日期字符串')
discharge_date = Column(Date, comment='出院日期')
doctor = Column(String(30), comment='主治医生')
admission_id = Column(String(50), comment='住院号')
doctor = Column(String(20), comment='主治医生')
admission_id = Column(String(20), comment='住院号')
create_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
creator = Column(String(255), comment='创建人')
update_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
@@ -138,7 +141,7 @@ class ZxIeSettlement(Base):
pk_ie_settlement = Column(INTEGER(11), primary_key=True, comment='结算清单信息抽取主键')
pk_phhd = Column(INTEGER(11), nullable=False, unique=True, comment='报销案子主键')
name = Column(String(30), comment='患者姓名')
name = Column(String(20), comment='患者姓名')
admission_date_str = Column(String(255), comment='入院日期字符串')
admission_date = Column(Date, comment='入院日期')
discharge_date_str = Column(String(255), comment='出院日期字符串')
@@ -152,9 +155,9 @@ class ZxIeSettlement(Base):
personal_funded_amount_str = Column(String(255), comment='自费金额字符串')
personal_funded_amount = Column(DECIMAL(18, 2), comment='自费金额')
medical_insurance_type_str = Column(String(255), comment='医保类型字符串')
medical_insurance_type = Column(String(40), comment='医保类型')
admission_id = Column(String(50), comment='住院号')
settlement_id = Column(String(50), comment='医保结算单号码')
medical_insurance_type = Column(String(10), comment='医保类型')
admission_id = Column(String(20), comment='住院号')
settlement_id = Column(String(30), comment='医保结算单号码')
create_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
creator = Column(String(255), comment='创建人')
update_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),

74
delete_deprecated_data.py Normal file
View File

@@ -0,0 +1,74 @@
# 删除本地数据库中的过期数据
import logging.config
from datetime import datetime, timedelta
from db import MysqlSession
from db.mysql import ZxPhhd, ZxIeCost, ZxIeDischarge, ZxIeResult, ZxIeSettlement, ZxPhrec
from log import LOGGING_CONFIG
# Retention window in days; rows whose billdate is older than this are
# eligible for deletion. (Original note: not recommended to be < 1 month.)
EXPIRATION_DAYS = 183
# Number of rows removed per batch. (Original note: best kept between
# 1000 and 10000 — balances transaction size against round trips.)
BATCH_SIZE = 5000
# Module-level database session object; assigned in the __main__ block
# and used by batch_delete_by_pk_phhd below.
session = None
def batch_delete_by_pk_phhd(model, pk_phhds):
    """
    Bulk-delete rows of *model* whose ``pk_phhd`` is in *pk_phhds*.

    Relies on the module-level ``session`` being initialised by the caller
    (done in the ``__main__`` block). Commits after the delete.

    Args:
        model: SQLAlchemy model class mapped to the target table.
        pk_phhds: list of ``pk_phhd`` key values to delete.

    Returns:
        The number of rows deleted.
    """
    delete_count = (
        session.query(model)
        .filter(model.pk_phhd.in_(pk_phhds))
        # synchronize_session=False: issue a bulk DELETE without trying to
        # reconcile in-memory session state (no matching objects are loaded).
        .delete(synchronize_session=False)
    )
    session.commit()
    logging.getLogger("sql").info(f"{model.__tablename__}成功删除{delete_count}条数据")
    return delete_count
if __name__ == '__main__':
    logging.config.dictConfig(LOGGING_CONFIG)
    # Cutoff for fully processed cases.
    deadline = datetime.now() - timedelta(days=EXPIRATION_DAYS)
    # Cutoff for unfinished cases: twice the retention window in the past.
    double_deadline = deadline - timedelta(days=EXPIRATION_DAYS)
    session = MysqlSession()
    try:
        while 1:
            # Cases that have completed the whole workflow
            # (paint_flag == "9"): delete once past the retention window.
            phhds = (session.query(ZxPhhd.pk_phhd)
                     .filter(ZxPhhd.paint_flag == "9")
                     .filter(ZxPhhd.billdate < deadline)
                     .limit(BATCH_SIZE)
                     .all())
            if not phhds or len(phhds) <= 0:
                # Cases that did not pass review and may still receive
                # re-taken/supplementary uploads: only delete after TWICE
                # the retention window has elapsed.
                phhds = (session.query(ZxPhhd.pk_phhd)
                         .filter(ZxPhhd.exsuccess_flag == "9")
                         .filter(ZxPhhd.paint_flag == "0")
                         .filter(ZxPhhd.billdate < double_deadline)
                         .limit(BATCH_SIZE)
                         .all())
                if not phhds or len(phhds) <= 0:
                    # No more matching rows in either category — done.
                    break
            pk_phhd_values = [phhd.pk_phhd for phhd in phhds]
            logging.getLogger("sql").info(f"过期的pk_phhd有{','.join(map(str, pk_phhd_values))}")
            # Delete dependent/child tables first, the parent ZxPhhd last,
            # so a failure mid-way never leaves orphan children.
            batch_delete_by_pk_phhd(ZxPhrec, pk_phhd_values)
            batch_delete_by_pk_phhd(ZxIeResult, pk_phhd_values)
            batch_delete_by_pk_phhd(ZxIeSettlement, pk_phhd_values)
            batch_delete_by_pk_phhd(ZxIeDischarge, pk_phhd_values)
            batch_delete_by_pk_phhd(ZxIeCost, pk_phhd_values)
            batch_delete_by_pk_phhd(ZxPhhd, pk_phhd_values)
    except Exception as e:
        # NOTE(review): each batch commits inside the helper, so rollback
        # here only undoes the batch that failed, not earlier ones.
        session.rollback()
        logging.getLogger('error').error('过期数据删除失败!', exc_info=e)
    finally:
        session.close()

167
doc_dewarp/.gitignore vendored
View File

@@ -1,167 +0,0 @@
input/
output/
runs/
*_dump/
*_log/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

View File

@@ -1,34 +0,0 @@
repos:
# Common hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/Lucas-C/pre-commit-hooks.git
rev: v1.5.1
hooks:
- id: remove-crlf
- id: remove-tabs
name: Tabs remover (Python)
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
args: [--whitespaces-count, '4']
# For Python files
- repo: https://github.com/psf/black.git
rev: 23.3.0
hooks:
- id: black
files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
- repo: https://github.com/pycqa/isort
rev: 5.11.5
hooks:
- id: isort
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.272
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --no-cache]

View File

@@ -1,398 +0,0 @@
import copy
from typing import Optional
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .extractor import BasicEncoder
from .position_encoding import build_position_encoding
from .weight_init import weight_init_
class attnLayer(nn.Layer):
def __init__(
self,
d_model,
nhead=8,
dim_feedforward=2048,
dropout=0.1,
activation="relu",
normalize_before=False,
):
super().__init__()
self.self_attn = nn.MultiHeadAttention(d_model, nhead, dropout=dropout)
self.multihead_attn_list = nn.LayerList(
[
copy.deepcopy(nn.MultiHeadAttention(d_model, nhead, dropout=dropout))
for i in range(2)
]
)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(p=dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2_list = nn.LayerList(
[copy.deepcopy(nn.LayerNorm(d_model)) for i in range(2)]
)
self.norm3 = nn.LayerNorm(d_model)
self.dropout1 = nn.Dropout(p=dropout)
self.dropout2_list = nn.LayerList(
[copy.deepcopy(nn.Dropout(p=dropout)) for i in range(2)]
)
self.dropout3 = nn.Dropout(p=dropout)
self.activation = _get_activation_fn(activation)
self.normalize_before = normalize_before
def with_pos_embed(self, tensor, pos: Optional[paddle.Tensor]):
return tensor if pos is None else tensor + pos
def forward_post(
self,
tgt,
memory_list,
tgt_mask=None,
memory_mask=None,
pos=None,
memory_pos=None,
):
q = k = self.with_pos_embed(tgt, pos)
tgt2 = self.self_attn(
q.transpose((1, 0, 2)),
k.transpose((1, 0, 2)),
value=tgt.transpose((1, 0, 2)),
attn_mask=tgt_mask,
)
tgt2 = tgt2.transpose((1, 0, 2))
tgt = tgt + self.dropout1(tgt2)
tgt = self.norm1(tgt)
for memory, multihead_attn, norm2, dropout2, m_pos in zip(
memory_list,
self.multihead_attn_list,
self.norm2_list,
self.dropout2_list,
memory_pos,
):
tgt2 = multihead_attn(
query=self.with_pos_embed(tgt, pos).transpose((1, 0, 2)),
key=self.with_pos_embed(memory, m_pos).transpose((1, 0, 2)),
value=memory.transpose((1, 0, 2)),
attn_mask=memory_mask,
).transpose((1, 0, 2))
tgt = tgt + dropout2(tgt2)
tgt = norm2(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
tgt = tgt + self.dropout3(tgt2)
tgt = self.norm3(tgt)
return tgt
def forward_pre(
self,
tgt,
memory,
tgt_mask=None,
memory_mask=None,
pos=None,
memory_pos=None,
):
tgt2 = self.norm1(tgt)
q = k = self.with_pos_embed(tgt2, pos)
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask)
tgt = tgt + self.dropout1(tgt2)
tgt2 = self.norm2(tgt)
tgt2 = self.multihead_attn(
query=self.with_pos_embed(tgt2, pos),
key=self.with_pos_embed(memory, memory_pos),
value=memory,
attn_mask=memory_mask,
)
tgt = tgt + self.dropout2(tgt2)
tgt2 = self.norm3(tgt)
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
tgt = tgt + self.dropout3(tgt2)
return tgt
def forward(
self,
tgt,
memory_list,
tgt_mask=None,
memory_mask=None,
pos=None,
memory_pos=None,
):
if self.normalize_before:
return self.forward_pre(
tgt,
memory_list,
tgt_mask,
memory_mask,
pos,
memory_pos,
)
return self.forward_post(
tgt,
memory_list,
tgt_mask,
memory_mask,
pos,
memory_pos,
)
def _get_clones(module, N):
return nn.LayerList([copy.deepcopy(module) for i in range(N)])
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
class TransDecoder(nn.Layer):
def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
super(TransDecoder, self).__init__()
attn_layer = attnLayer(hidden_dim)
self.layers = _get_clones(attn_layer, num_attn_layers)
self.position_embedding = build_position_encoding(hidden_dim)
def forward(self, image: paddle.Tensor, query_embed: paddle.Tensor):
pos = self.position_embedding(
paddle.ones([image.shape[0], image.shape[2], image.shape[3]], dtype="bool")
)
b, c, h, w = image.shape
image = image.flatten(2).transpose(perm=[2, 0, 1])
pos = pos.flatten(2).transpose(perm=[2, 0, 1])
for layer in self.layers:
query_embed = layer(query_embed, [image], pos=pos, memory_pos=[pos, pos])
query_embed = query_embed.transpose(perm=[1, 2, 0]).reshape([b, c, h, w])
return query_embed
class TransEncoder(nn.Layer):
def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
super(TransEncoder, self).__init__()
attn_layer = attnLayer(hidden_dim)
self.layers = _get_clones(attn_layer, num_attn_layers)
self.position_embedding = build_position_encoding(hidden_dim)
def forward(self, image: paddle.Tensor):
pos = self.position_embedding(
paddle.ones([image.shape[0], image.shape[2], image.shape[3]], dtype="bool")
)
b, c, h, w = image.shape
image = image.flatten(2).transpose(perm=[2, 0, 1])
pos = pos.flatten(2).transpose(perm=[2, 0, 1])
for layer in self.layers:
image = layer(image, [image], pos=pos, memory_pos=[pos, pos])
image = image.transpose(perm=[1, 2, 0]).reshape([b, c, h, w])
return image
class FlowHead(nn.Layer):
def __init__(self, input_dim=128, hidden_dim=256):
super(FlowHead, self).__init__()
self.conv1 = nn.Conv2D(input_dim, hidden_dim, 3, padding=1)
self.conv2 = nn.Conv2D(hidden_dim, 2, 3, padding=1)
self.relu = nn.ReLU()
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class UpdateBlock(nn.Layer):
def __init__(self, hidden_dim: int = 128):
super(UpdateBlock, self).__init__()
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
self.mask = nn.Sequential(
nn.Conv2D(hidden_dim, 256, 3, padding=1),
nn.ReLU(),
nn.Conv2D(256, 64 * 9, 1, padding=0),
)
def forward(self, image, coords):
mask = 0.25 * self.mask(image)
dflow = self.flow_head(image)
coords = coords + dflow
return mask, coords
def coords_grid(batch, ht, wd):
coords = paddle.meshgrid(paddle.arange(end=ht), paddle.arange(end=wd))
coords = paddle.stack(coords[::-1], axis=0).astype(dtype="float32")
return coords[None].tile([batch, 1, 1, 1])
def upflow8(flow, mode="bilinear"):
new_size = 8 * flow.shape[2], 8 * flow.shape[3]
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
class OverlapPatchEmbed(nn.Layer):
"""Image to Patch Embedding"""
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
patch_size = (
patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
)
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
self.num_patches = self.H * self.W
self.proj = nn.Conv2D(
in_chans,
embed_dim,
patch_size,
stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2),
)
self.norm = nn.LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
weight_init_(m, "trunc_normal_", std=0.02)
elif isinstance(m, nn.LayerNorm):
weight_init_(m, "Constant", value=1.0)
elif isinstance(m, nn.Conv2D):
weight_init_(
m.weight, "kaiming_normal_", mode="fan_out", nonlinearity="relu"
)
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2)
perm = list(range(x.ndim))
perm[1] = 2
perm[2] = 1
x = x.transpose(perm=perm)
x = self.norm(x)
return x, H, W
class GeoTr(nn.Layer):
def __init__(self):
super(GeoTr, self).__init__()
self.hidden_dim = hdim = 256
self.fnet = BasicEncoder(output_dim=hdim, norm_fn="instance")
self.encoder_block = [("encoder_block" + str(i)) for i in range(3)]
for i in self.encoder_block:
self.__setattr__(i, TransEncoder(2, hidden_dim=hdim))
self.down_layer = [("down_layer" + str(i)) for i in range(2)]
for i in self.down_layer:
self.__setattr__(i, nn.Conv2D(256, 256, 3, stride=2, padding=1))
self.decoder_block = [("decoder_block" + str(i)) for i in range(3)]
for i in self.decoder_block:
self.__setattr__(i, TransDecoder(2, hidden_dim=hdim))
self.up_layer = [("up_layer" + str(i)) for i in range(2)]
for i in self.up_layer:
self.__setattr__(
i, nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
)
self.query_embed = nn.Embedding(81, self.hidden_dim)
self.update_block = UpdateBlock(self.hidden_dim)
def initialize_flow(self, img):
N, _, H, W = img.shape
coodslar = coords_grid(N, H, W)
coords0 = coords_grid(N, H // 8, W // 8)
coords1 = coords_grid(N, H // 8, W // 8)
return coodslar, coords0, coords1
def upsample_flow(self, flow, mask):
N, _, H, W = flow.shape
mask = mask.reshape([N, 1, 9, 8, 8, H, W])
mask = F.softmax(mask, axis=2)
up_flow = F.unfold(8 * flow, [3, 3], paddings=1)
up_flow = up_flow.reshape([N, 2, 9, 1, 1, H, W])
up_flow = paddle.sum(mask * up_flow, axis=2)
up_flow = up_flow.transpose(perm=[0, 1, 4, 2, 5, 3])
return up_flow.reshape([N, 2, 8 * H, 8 * W])
def forward(self, image):
fmap = self.fnet(image)
fmap = F.relu(fmap)
fmap1 = self.__getattr__(self.encoder_block[0])(fmap)
fmap1d = self.__getattr__(self.down_layer[0])(fmap1)
fmap2 = self.__getattr__(self.encoder_block[1])(fmap1d)
fmap2d = self.__getattr__(self.down_layer[1])(fmap2)
fmap3 = self.__getattr__(self.encoder_block[2])(fmap2d)
query_embed0 = self.query_embed.weight.unsqueeze(1).tile([1, fmap3.shape[0], 1])
fmap3d_ = self.__getattr__(self.decoder_block[0])(fmap3, query_embed0)
fmap3du_ = (
self.__getattr__(self.up_layer[0])(fmap3d_)
.flatten(2)
.transpose(perm=[2, 0, 1])
)
fmap2d_ = self.__getattr__(self.decoder_block[1])(fmap2, fmap3du_)
fmap2du_ = (
self.__getattr__(self.up_layer[1])(fmap2d_)
.flatten(2)
.transpose(perm=[2, 0, 1])
)
fmap_out = self.__getattr__(self.decoder_block[2])(fmap1, fmap2du_)
coodslar, coords0, coords1 = self.initialize_flow(image)
coords1 = coords1.detach()
mask, coords1 = self.update_block(fmap_out, coords1)
flow_up = self.upsample_flow(coords1 - coords0, mask)
bm_up = coodslar + flow_up
return bm_up

View File

@@ -1,73 +0,0 @@
# DocTrPP: DocTr++ in PaddlePaddle
## Introduction
This is a PaddlePaddle implementation of DocTr++. The original paper is [DocTr++: Deep Unrestricted Document Image Rectification](https://arxiv.org/abs/2304.08796). The original code is [here](https://github.com/fh2019ustc/DocTr-Plus).
![demo](https://github.com/GreatV/DocTrPP/assets/17264618/4e491512-bfc4-4e69-a833-fd1c6e17158c)
## Requirements
You need to install the latest version of PaddlePaddle, which is done through this [link](https://www.paddlepaddle.org.cn/).
## Training
1. Data Preparation
To prepare datasets, refer to [doc3D](https://github.com/cvlab-stonybrook/doc3D-dataset).
2. Training
```shell
sh train.sh
```
or
```shell
export OPENCV_IO_ENABLE_OPENEXR=1
export CUDA_VISIBLE_DEVICES=0
python train.py --img-size 288 \
--name "DocTr++" \
--batch-size 12 \
--lr 2.5e-5 \
--exist-ok \
--use-vdl
```
3. Load Trained Model and Continue Training
```shell
export OPENCV_IO_ENABLE_OPENEXR=1
export CUDA_VISIBLE_DEVICES=0
python train.py --img-size 288 \
--name "DocTr++" \
--batch-size 12 \
--lr 2.5e-5 \
--resume "runs/train/DocTr++/weights/last.ckpt" \
--exist-ok \
--use-vdl
```
## Test and Inference
Test the dewarp result on a single image:
```shell
python predict.py -i "crop/12_2 copy.png" -m runs/train/DocTr++/weights/best.ckpt -o 12.2.png
```
![document image rectification](https://raw.githubusercontent.com/greatv/DocTrPP/main/doc/imgs/document_image_rectification.jpg)
## Export to onnx
```
pip install paddle2onnx
python export.py -m ./best.ckpt --format onnx
```
## Model Download
The trained model can be downloaded from [here](https://github.com/GreatV/DocTrPP/releases/download/v0.0.2/best.ckpt).

View File

@@ -1,4 +0,0 @@
from onnxruntime import InferenceSession
# Module-level ONNX Runtime session for the DocTr++ document-dewarping
# model, loaded once at import time and pinned to CUDA device 0.
DOC_TR = InferenceSession("model/dewarp_model/doc_tr_pp.onnx",
                          providers=["CUDAExecutionProvider"], provider_options=[{"device_id": 0}])

View File

@@ -1,133 +0,0 @@
import argparse
import os
import random
import cv2
import hdf5storage as h5
import matplotlib.pyplot as plt
import numpy as np
import paddle
import paddle.nn.functional as F
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
parser = argparse.ArgumentParser()
parser.add_argument(
"--data_root",
nargs="?",
type=str,
default="~/datasets/doc3d/",
help="Path to the downloaded dataset",
)
parser.add_argument(
"--folder", nargs="?", type=int, default=1, help="Folder ID to read from"
)
parser.add_argument(
"--output",
nargs="?",
type=str,
default="output.png",
help="Output filename for the image",
)
args = parser.parse_args()
root = os.path.expanduser(args.data_root)
folder = args.folder
dirname = os.path.join(root, "img", str(folder))
choices = [f for f in os.listdir(dirname) if "png" in f]
fname = random.choice(choices)
# Read Image
img_path = os.path.join(dirname, fname)
img = cv2.imread(img_path)
# Read 3D Coords
wc_path = os.path.join(root, "wc", str(folder), fname[:-3] + "exr")
wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
# scale wc
# value obtained from the entire dataset
xmx, xmn, ymx, ymn, zmx, zmn = (
1.2539363,
-1.2442188,
1.2396319,
-1.2289206,
0.6436657,
-0.67492497,
)
wc[:, :, 0] = (wc[:, :, 0] - zmn) / (zmx - zmn)
wc[:, :, 1] = (wc[:, :, 1] - ymn) / (ymx - ymn)
wc[:, :, 2] = (wc[:, :, 2] - xmn) / (xmx - xmn)
# Read Backward Map
bm_path = os.path.join(root, "bm", str(folder), fname[:-3] + "mat")
bm = h5.loadmat(bm_path)["bm"]
# Read UV Map
uv_path = os.path.join(root, "uv", str(folder), fname[:-3] + "exr")
uv = cv2.imread(uv_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
# Read Depth Map
dmap_path = os.path.join(root, "dmap", str(folder), fname[:-3] + "exr")
dmap = cv2.imread(dmap_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)[:, :, 0]
# do some clipping and scaling to display it
dmap[dmap > 30.0] = 30
dmap = 1 - ((dmap - np.min(dmap)) / (np.max(dmap) - np.min(dmap)))
# Read Normal Map
norm_path = os.path.join(root, "norm", str(folder), fname[:-3] + "exr")
norm = cv2.imread(norm_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
# Read Albedo
alb_path = os.path.join(root, "alb", str(folder), fname[:-3] + "png")
alb = cv2.imread(alb_path)
# Read Checkerboard Image
recon_path = os.path.join(root, "recon", str(folder), fname[:-8] + "chess480001.png")
recon = cv2.imread(recon_path)
# Display image and GTs
# use the backward mapping to dewarp the image
# scale bm to -1.0 to 1.0
bm_ = bm / np.array([448, 448])
bm_ = (bm_ - 0.5) * 2
bm_ = np.reshape(bm_, (1, 448, 448, 2))
bm_ = paddle.to_tensor(bm_, dtype="float32")
img_ = alb.transpose((2, 0, 1)).astype(np.float32) / 255.0
img_ = np.expand_dims(img_, 0)
img_ = paddle.to_tensor(img_, dtype="float32")
uw = F.grid_sample(img_, bm_)
uw = uw[0].numpy().transpose((1, 2, 0))
f, axrr = plt.subplots(2, 5)
for ax in axrr:
for a in ax:
a.set_xticks([])
a.set_yticks([])
axrr[0][0].imshow(img)
axrr[0][0].title.set_text("image")
axrr[0][1].imshow(wc)
axrr[0][1].title.set_text("3D coords")
axrr[0][2].imshow(bm[:, :, 0])
axrr[0][2].title.set_text("bm 0")
axrr[0][3].imshow(bm[:, :, 1])
axrr[0][3].title.set_text("bm 1")
if uv is None:
uv = np.zeros_like(img)
axrr[0][4].imshow(uv)
axrr[0][4].title.set_text("uv map")
axrr[1][0].imshow(dmap)
axrr[1][0].title.set_text("depth map")
axrr[1][1].imshow(norm)
axrr[1][1].title.set_text("normal map")
axrr[1][2].imshow(alb)
axrr[1][2].title.set_text("albedo")
axrr[1][3].imshow(recon)
axrr[1][3].title.set_text("checkerboard")
axrr[1][4].imshow(uw)
axrr[1][4].title.set_text("gt unwarped")
plt.tight_layout()
plt.savefig(args.output)

View File

@@ -1,21 +0,0 @@
import cv2
import numpy as np
import paddle
from . import DOC_TR
from .utils import to_tensor, to_image
def dewarp_image(image):
    """
    Rectify (dewarp) a distorted document photo with the DocTr++ model.

    Args:
        image: input image array; assumed HWC uint8 in OpenCV convention
            (presumably BGR — TODO confirm against callers of to_tensor).

    Returns:
        The dewarped image at the original resolution, in whatever format
        ``to_image`` produces.
    """
    # The ONNX model takes a fixed 288x288 input.
    img = cv2.resize(image, (288, 288)).astype(np.float32)
    # Full-resolution tensor of the original image, sampled from below.
    y = to_tensor(image)
    img = np.transpose(img, (2, 0, 1))  # HWC -> CHW for the model
    # Model output `bm` is used as a backward (sampling-coordinate) map.
    bm = DOC_TR.run(None, {"image": img[None,]})[0]
    bm = paddle.to_tensor(bm)
    # Upsample the coordinate map to the original image's spatial size.
    bm = paddle.nn.functional.interpolate(
        bm, y.shape[2:], mode="bilinear", align_corners=False
    )
    bm_nhwc = np.transpose(bm, (0, 2, 3, 1))
    # Map coordinates from [0, 288] into [-1, 1] as grid_sample expects.
    out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
    return to_image(out)

View File

@@ -1,161 +0,0 @@
# Download the Doc3D dataset archives from the PaddleSeg BOS bucket.
# Each modality is split into numbered zip archives. The available index
# ranges differ per modality (alb only has parts 1-8; uv_9 is absent
# upstream), so the exact index lists below are preserved.
BASE_URL="https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d"

download_parts() {
    # $1: modality prefix; remaining args: archive indices to fetch.
    local prefix="$1"
    shift
    local i
    for i in "$@"; do
        wget "${BASE_URL}/${prefix}_${i}.zip"
    done
}

download_parts img   $(seq 1 21)
download_parts wc    $(seq 1 21)
download_parts bm    $(seq 1 21)
# NOTE: uv_9.zip is intentionally skipped (not mirrored upstream).
download_parts uv    $(seq 1 8) $(seq 10 21)
download_parts alb   $(seq 1 8)
download_parts recon $(seq 1 21)
download_parts norm  $(seq 1 21)
download_parts dmap  $(seq 1 21)

Binary file not shown.

Before

Width:  |  Height:  |  Size: 76 KiB

View File

@@ -1,129 +0,0 @@
import collections
import os
import random
import albumentations as A
import cv2
import hdf5storage as h5
import numpy as np
import paddle
from paddle import io
# Set random seed
# Fixed seed so the random crop jitter in Doc3dDataset.tight_crop is reproducible.
random.seed(12345678)
class Doc3dDataset(io.Dataset):
    """Doc3D dataset of distorted document images and backward maps.

    Each sample is ``(image, bm)`` where ``image`` is a CHW float32 tensor
    in [0, 1] and ``bm`` is an HWC backward map rescaled to ``image_size``.
    """

    def __init__(self, root, split="train", is_augment=False, image_size=512):
        # Root folder containing img/, wc/, bm/ plus train.txt / val.txt.
        self.root = os.path.expanduser(root)
        self.split = split
        self.is_augment = is_augment
        self.files = collections.defaultdict(list)
        # Normalize to a (width, height) tuple.
        self.image_size = (
            image_size if isinstance(image_size, tuple) else (image_size, image_size)
        )
        # Augmentation (photometric only, so geometry stays aligned with bm)
        self.augmentation = A.Compose(
            [
                A.ColorJitter(),
            ]
        )
        # Load the sample-id lists for both splits up front.
        for split in ["train", "val"]:
            path = os.path.join(self.root, split + ".txt")
            file_list = []
            with open(path, "r") as file:
                file_list = [file_id.rstrip() for file_id in file.readlines()]
            self.files[split] = file_list

    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):
        image_name = self.files[self.split][index]
        # Read image
        image_path = os.path.join(self.root, "img", image_name + ".png")
        image = cv2.imread(image_path)
        # Read 3D Coordinates (world-coordinate EXR map)
        wc_path = os.path.join(self.root, "wc", image_name + ".exr")
        wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
        # Read backward map
        bm_path = os.path.join(self.root, "bm", image_name + ".mat")
        bm = h5.loadmat(bm_path)["bm"]
        image, bm = self.transform(wc, bm, image)
        return image, bm

    def tight_crop(self, wc: np.ndarray):
        """Crop ``wc`` tightly around the document with a random jittered margin.

        Returns the cropped world coordinates plus the (top, bottom, left,
        right) offsets needed to apply the same crop to the RGB image.
        """
        # Foreground mask: pixels where all three world coordinates are non-zero.
        mask = ((wc[:, :, 0] != 0) & (wc[:, :, 1] != 0) & (wc[:, :, 2] != 0)).astype(
            np.uint8
        )
        mask_size = mask.shape
        [y, x] = mask.nonzero()
        min_x = min(x)
        max_x = max(x)
        min_y = min(y)
        max_y = max(y)
        wc = wc[min_y : max_y + 1, min_x : max_x + 1, :]
        s = 10
        wc = np.pad(wc, ((s, s), (s, s), (0, 0)), "constant")
        # Random margin jitter; cx2/cy2 >= 1 so the negative slice indices
        # below never degenerate to 0 (which would select nothing).
        cx1 = random.randint(0, 2 * s)
        cx2 = random.randint(0, 2 * s) + 1
        cy1 = random.randint(0, 2 * s)
        cy2 = random.randint(0, 2 * s) + 1
        wc = wc[cy1:-cy2, cx1:-cx2, :]
        top: int = min_y - s + cy1
        bottom: int = mask_size[0] - max_y - s + cy2
        left: int = min_x - s + cx1
        right: int = mask_size[1] - max_x - s + cx2
        # Clamp so the image crop in transform() stays within bounds.
        top = np.clip(top, 0, mask_size[0])
        bottom = np.clip(bottom, 1, mask_size[0] - 1)
        left = np.clip(left, 0, mask_size[1])
        right = np.clip(right, 1, mask_size[1] - 1)
        return wc, top, bottom, left, right

    def transform(self, wc, bm, img):
        """Crop, (optionally) augment and resize ``img``; rescale ``bm`` to match."""
        wc, top, bottom, left, right = self.tight_crop(wc)
        img = img[top:-bottom, left:-right, :]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.is_augment:
            img = self.augmentation(image=img)["image"]
        # resize image
        img = cv2.resize(img, self.image_size)
        img = img.astype(np.float32) / 255.0
        img = img.transpose(2, 0, 1)
        # resize bm: shift into crop coordinates, then normalize by the
        # cropped extent (448 presumably is the Doc3D render size — confirm).
        bm = bm.astype(np.float32)
        bm[:, :, 1] = bm[:, :, 1] - top
        bm[:, :, 0] = bm[:, :, 0] - left
        bm = bm / np.array([448.0 - left - right, 448.0 - top - bottom])
        bm0 = cv2.resize(bm[:, :, 0], (self.image_size[0], self.image_size[1]))
        bm1 = cv2.resize(bm[:, :, 1], (self.image_size[0], self.image_size[1]))
        # Scale normalized map back up to target pixel coordinates.
        bm0 = bm0 * self.image_size[0]
        bm1 = bm1 * self.image_size[1]
        bm = np.stack([bm0, bm1], axis=-1)
        img = paddle.to_tensor(img).astype(dtype="float32")
        bm = paddle.to_tensor(bm).astype(dtype="float32")
        return img, bm

View File

@@ -1,66 +0,0 @@
import argparse
import os
import paddle
from GeoTr import GeoTr
def export(args):
    """Export a trained GeoTr checkpoint to a static graph and/or ONNX.

    Args:
        args: Parsed CLI namespace with ``model`` (checkpoint path),
            ``imgsz`` (square input size) and ``format`` ("static" or "onnx").
    """
    model_path = args.model
    imgsz = args.imgsz
    export_format = args.format  # renamed: don't shadow the builtin ``format``
    model = GeoTr()
    checkpoint = paddle.load(model_path)
    model.set_state_dict(checkpoint["model"])
    model.eval()
    dirname = os.path.dirname(model_path)
    if export_format in ("static", "onnx"):
        # Trace to a static graph with a fixed 1x3xHxW input spec.
        model = paddle.jit.to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(shape=[1, 3, imgsz, imgsz], dtype="float32")
            ],
            full_graph=True,
        )
        paddle.jit.save(model, os.path.join(dirname, "model"))
    if export_format == "onnx":
        import subprocess  # local import: only needed on the ONNX path

        onnx_path = os.path.join(dirname, "model.onnx")
        # Argument list with no shell: safe for paths containing spaces,
        # and check=True surfaces converter failures instead of ignoring
        # them as the previous os.system() call did.
        subprocess.run(
            [
                "paddle2onnx",
                "--model_dir", dirname,
                "--model_filename", "model.pdmodel",
                "--params_filename", "model.pdiparams",
                "--save_file", onnx_path,
                "--opset_version", "11",
            ],
            check=True,
        )
if __name__ == "__main__":
    # CLI entry point: parse arguments and run the export.
    parser = argparse.ArgumentParser(description="export model")
    parser.add_argument(
        "--model",
        "-m",
        nargs="?",
        type=str,
        default="",
        help="The path of model",
    )
    parser.add_argument(
        "--imgsz", type=int, default=288, help="The size of input image"
    )
    parser.add_argument(
        "--format",
        type=str,
        default="static",
        help="The format of exported model, which can be static or onnx",
    )
    args = parser.parse_args()
    export(args)

View File

@@ -1,110 +0,0 @@
import paddle.nn as nn
from .weight_init import weight_init_
class ResidualBlock(nn.Layer):
    """Residual Block with custom normalization.

    Args:
        in_planes (int): Number of input channels.
        planes (int): Number of output channels.
        norm_fn (str): One of "group", "batch", "instance" or "none".
        stride (int): Stride of the first conv; when != 1 a 1x1-conv
            downsample branch is added so the skip connection matches.
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2D(in_planes, planes, 3, padding=1, stride=stride)
        self.conv2 = nn.Conv2D(planes, planes, 3, padding=1)
        self.relu = nn.ReLU()
        if norm_fn == "group":
            num_groups = planes // 8
            self.norm1 = nn.GroupNorm(num_groups, planes)
            self.norm2 = nn.GroupNorm(num_groups, planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups, planes)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2D(planes)
            self.norm2 = nn.BatchNorm2D(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2D(planes)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2D(planes)
            self.norm2 = nn.InstanceNorm2D(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2D(planes)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()
        else:
            # Fail fast: previously an unknown norm_fn silently left the
            # norm layers undefined and crashed later in forward().
            raise ValueError(f"unsupported norm_fn: {norm_fn!r}")
        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2D(in_planes, planes, 1, stride=stride), self.norm3
            )

    def forward(self, x):
        """Two conv-norm-relu stages plus the (optionally downsampled) skip."""
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        if self.downsample is not None:
            x = self.downsample(x)
        return self.relu(x + y)
class BasicEncoder(nn.Layer):
    """Basic Encoder with custom normalization.

    A small residual CNN: 7x7/2 stem conv, three residual stages
    (64 -> 128 -> 192 channels, the last two at stride 2), then a
    1x1 projection to ``output_dim`` channels.
    """

    def __init__(self, output_dim=128, norm_fn="batch"):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn
        # Stem normalization, selected by name.
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(8, 64)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2D(64)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2D(64)
        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()
        self.conv1 = nn.Conv2D(3, 64, 7, stride=2, padding=3)
        self.relu1 = nn.ReLU()
        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(128, stride=2)
        self.layer3 = self._make_layer(192, stride=2)
        self.conv2 = nn.Conv2D(192, output_dim, 1)
        # Kaiming init for convs, constant init for every norm layer.
        norm_types = (nn.BatchNorm2D, nn.InstanceNorm2D, nn.GroupNorm)
        for sublayer in self.sublayers():
            if isinstance(sublayer, nn.Conv2D):
                weight_init_(
                    sublayer.weight,
                    "kaiming_normal_",
                    mode="fan_out",
                    nonlinearity="relu",
                )
            elif isinstance(sublayer, norm_types):
                weight_init_(sublayer, "Constant", value=1, bias_value=0.0)

    def _make_layer(self, dim, stride=1):
        # Two residual blocks; only the first may change stride/width.
        blocks = (
            ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride),
            ResidualBlock(dim, dim, self.norm_fn, stride=1),
        )
        self.in_planes = dim
        return nn.Sequential(*blocks)

    def forward(self, x):
        """Encode an NCHW image batch into an ``output_dim``-channel map."""
        out = self.relu1(self.norm1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        return self.conv2(out)

View File

@@ -1,48 +0,0 @@
import copy
import os
import matplotlib.pyplot as plt
import paddle.optimizer as optim
from GeoTr import GeoTr
def plot_lr_scheduler(optimizer, scheduler, epochs=65, save_dir="", steps_per_epoch=30):
    """
    Plot the learning rate scheduler.

    Steps copies of the optimizer/scheduler through a simulated run and
    saves the resulting curve as ``lr_scheduler.png`` in ``save_dir``.

    Args:
        optimizer: Optimizer driving the scheduler.
        scheduler: Learning-rate scheduler to visualize.
        epochs (int): Number of simulated epochs.
        save_dir (str): Directory the figure is written into.
        steps_per_epoch (int): Scheduler steps per epoch (was a hard-coded 30).
    """
    # NOTE(review): copy.copy is shallow, so stepping here may still mutate
    # the originals' internal state — confirm before reusing them afterwards.
    optimizer = copy.copy(optimizer)
    scheduler = copy.copy(scheduler)
    lr = []
    for _ in range(epochs):
        for _ in range(steps_per_epoch):
            lr.append(scheduler.get_lr())
            optimizer.step()
            scheduler.step()
    # Fractional-epoch x axis for a smooth per-step curve.
    epoch = [float(i) / steps_per_epoch for i in range(len(lr))]
    plt.figure()
    plt.plot(epoch, lr, ".-", label="Learning Rate")
    plt.xlabel("epoch")
    plt.ylabel("Learning Rate")
    plt.title("Learning Rate Scheduler")
    plt.savefig(os.path.join(save_dir, "lr_scheduler.png"), dpi=300)
    plt.close()
if __name__ == "__main__":
    # Visualize the OneCycle schedule used by training
    # (1950 total steps = 65 epochs x 30 steps).
    model = GeoTr()
    schaduler = optim.lr.OneCycleLR(
        max_learning_rate=1e-4,
        total_steps=1950,
        phase_pct=0.1,
        end_learning_rate=1e-4 / 2.5e5,
    )
    optimizer = optim.AdamW(learning_rate=schaduler, parameters=model.parameters())
    plot_lr_scheduler(
        scheduler=schaduler, optimizer=optimizer, epochs=65, save_dir="./"
    )

View File

@@ -1,124 +0,0 @@
import math
from typing import Optional
import paddle
import paddle.nn as nn
import paddle.nn.initializer as init
class NestedTensor(object):
    """A tensor batch bundled with its (optional) padding mask."""

    def __init__(self, tensors, mask: Optional[paddle.Tensor]):
        # Stored as-is; ``mask`` may be None when there is no padding.
        self.tensors = tensors
        self.mask = mask

    def decompose(self):
        """Return the (tensors, mask) pair."""
        return self.tensors, self.mask

    def __repr__(self):
        # Delegate to the wrapped tensors' own representation.
        return str(self.tensors)
class PositionEmbeddingSine(nn.Layer):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        # num_pos_feats: channels per axis (the output has 2 * num_pos_feats).
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, mask):
        # mask: (N, H, W); cumulative sums turn it into per-pixel row/column
        # indices (assumes valid pixels are marked 1 — TODO confirm caller).
        assert mask is not None
        y_embed = mask.cumsum(axis=1, dtype="float32")
        x_embed = mask.cumsum(axis=2, dtype="float32")
        if self.normalize:
            eps = 1e-06
            # Scale positions into [0, self.scale] using the largest index.
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        # Geometric frequency ladder; consecutive channel pairs share a frequency.
        dim_t = paddle.arange(end=self.num_pos_feats, dtype="float32")
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin (even channels) and cos (odd channels).
        pos_x = paddle.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), axis=4
        ).flatten(3)
        pos_y = paddle.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), axis=4
        ).flatten(3)
        # (N, H, W, C) -> (N, C, H, W), y-features before x-features.
        pos = paddle.concat((pos_y, pos_x), axis=3).transpose(perm=[0, 3, 1, 2])
        return pos
class PositionEmbeddingLearned(nn.Layer):
    """
    Absolute pos embedding, learned.
    """

    def __init__(self, num_pos_feats=256):
        super().__init__()
        # Independent learned tables for up to 50 rows and 50 columns.
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize both embedding tables from a uniform distribution."""
        uniform_ = init.Uniform()
        uniform_(self.row_embed.weight)
        uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        """Build an (N, 2*num_pos_feats, H, W) positional map for the batch."""
        x = tensor_list.tensors
        height, width = x.shape[-2:]
        col_ids = paddle.arange(end=width)
        row_ids = paddle.arange(end=height)
        col_feats = self.col_embed(col_ids)
        row_feats = self.row_embed(row_ids)
        # Broadcast column features down the rows and row features across
        # the columns, concatenate along channels, then tile over the batch.
        grid = paddle.concat(
            [
                col_feats.unsqueeze(0).tile([height, 1, 1]),
                row_feats.unsqueeze(1).tile([1, width, 1]),
            ],
            axis=-1,
        )
        return grid.transpose([2, 0, 1]).unsqueeze(0).tile([x.shape[0], 1, 1, 1])
def build_position_encoding(hidden_dim=512, position_embedding="sine"):
    """Create a positional-encoding layer.

    Args:
        hidden_dim (int): Model width; each spatial axis gets half of it.
        position_embedding (str): "sine"/"v2" or "learned"/"v3".

    Raises:
        ValueError: If ``position_embedding`` names an unknown variant.
    """
    N_steps = hidden_dim // 2
    if position_embedding in ("v2", "sine"):
        return PositionEmbeddingSine(N_steps, normalize=True)
    if position_embedding in ("v3", "learned"):
        return PositionEmbeddingLearned(N_steps)
    raise ValueError(f"not supported {position_embedding}")

View File

@@ -1,69 +0,0 @@
import argparse
import cv2
import paddle
from GeoTr import GeoTr
from utils import to_image, to_tensor
def run(args):
    """Load a GeoTr checkpoint, dewarp one image and write the result.

    Args:
        args: Namespace with ``image``, ``model`` and ``output`` paths.
    """
    image_path = args.image
    model_path = args.model
    output_path = args.output
    # Restore trained weights into a fresh model.
    checkpoint = paddle.load(model_path)
    state_dict = checkpoint["model"]
    model = GeoTr()
    model.set_state_dict(state_dict)
    model.eval()
    # The network runs at 288x288; the predicted backward map is then
    # upsampled so sampling happens at the original resolution.
    img_org = cv2.imread(image_path)
    img = cv2.resize(img_org, (288, 288))
    x = to_tensor(img)
    y = to_tensor(img_org)
    bm = model(x)
    bm = paddle.nn.functional.interpolate(
        bm, y.shape[2:], mode="bilinear", align_corners=False
    )
    # grid_sample expects NHWC coordinates normalized to [-1, 1].
    bm_nhwc = bm.transpose([0, 2, 3, 1])
    out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
    cv2.imwrite(output_path, to_image(out))
if __name__ == "__main__":
    # CLI entry point: dewarp a single image with a trained checkpoint.
    parser = argparse.ArgumentParser(description="predict")
    parser.add_argument(
        "--image",
        "-i",
        nargs="?",
        type=str,
        default="",
        help="The path of image",
    )
    parser.add_argument(
        "--model",
        "-m",
        nargs="?",
        type=str,
        default="",
        help="The path of model",
    )
    parser.add_argument(
        "--output",
        "-o",
        nargs="?",
        type=str,
        default="",
        help="The path of output",
    )
    args = parser.parse_args()
    print(args)
    run(args)

View File

@@ -1,7 +0,0 @@
hdf5storage
loguru
numpy
scipy
opencv-python
matplotlib
albumentations

View File

@@ -1,57 +0,0 @@
import argparse
import glob
import os
import random
from pathlib import Path
from loguru import logger
# Fixed seed so the shuffled train/val split is reproducible.
random.seed(1234567)
def run(args):
    """Split the Doc3D image list into train/val label files.

    Globs ``<data_root>/img/*/*.png``, shuffles deterministically and
    writes ``train.txt`` / ``val.txt`` containing "<folder>/<stem>" ids.

    Args:
        args: Namespace with ``data_root`` and ``train_ratio``.
    """
    data_root = os.path.expanduser(args.data_root)
    ratio = args.train_ratio
    data_path = os.path.join(data_root, "img", "*", "*.png")
    img_list = glob.glob(data_path, recursive=True)
    # Sort before shuffling so the seeded shuffle is reproducible across
    # file systems. (Bug fix: the old code called sorted(img_list) and
    # discarded the result, leaving the glob order unsorted.)
    img_list.sort()
    random.shuffle(img_list)
    train_size = int(len(img_list) * ratio)

    def _write_split(text_path, items):
        # One "<parent-folder>/<stem>" id per line, extension stripped.
        with open(text_path, "w") as file:
            for item in items:
                parts = Path(item).parts
                item = os.path.join(parts[-2], parts[-1])
                file.write("%s\n" % item.split(".png")[0])

    train_text_path = os.path.join(data_root, "train.txt")
    _write_split(train_text_path, img_list[:train_size])
    val_text_path = os.path.join(data_root, "val.txt")
    _write_split(val_text_path, img_list[train_size:])
    logger.info(f"TRAIN LABEL: {train_text_path}")
    logger.info(f"VAL LABEL: {val_text_path}")
if __name__ == "__main__":
    # CLI entry point: build the train/val split files.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_root",
        type=str,
        default="~/datasets/doc3d",
        help="Data path to load data",
    )
    parser.add_argument(
        "--train_ratio", type=float, default=0.8, help="Ratio of training data"
    )
    args = parser.parse_args()
    logger.info(args)
    run(args)

View File

@@ -1,381 +0,0 @@
import argparse
import inspect
import os
import random
from pathlib import Path
from typing import Optional
import numpy as np
import paddle
import paddle.nn as nn
import paddle.optimizer as optim
from loguru import logger
from paddle.io import DataLoader
from paddle.nn import functional as F
from paddle_msssim import ms_ssim, ssim
from doc3d_dataset import Doc3dDataset
from GeoTr import GeoTr
from utils import to_image
# Resolved path of this script and its directory (used as the project root).
FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]
# Distributed rank from the launcher; -1 means single-process training.
RANK = int(os.getenv("RANK", -1))
def init_seeds(seed=0, deterministic=False):
    """Seed the python, numpy and paddle RNGs; optionally force determinism."""
    for seeder in (random.seed, np.random.seed, paddle.seed):
        seeder(seed)
    if deterministic:
        # Pin cuDNN/cuBLAS to deterministic kernels and fix hash ordering.
        os.environ["FLAGS_cudnn_deterministic"] = "1"
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        os.environ["PYTHONHASHSEED"] = str(seed)
def colorstr(*input):
    """Wrap a string in ANSI escape codes, e.g. colorstr('blue', 'hello world').

    The last positional argument is the string; any preceding arguments are
    color/style names. With a single argument, 'blue' + 'bold' is assumed.
    See https://en.wikipedia.org/wiki/ANSI_escape_code.
    """
    colors = {
        "black": "\033[30m",  # basic colors
        "red": "\033[31m",
        "green": "\033[32m",
        "yellow": "\033[33m",
        "blue": "\033[34m",
        "magenta": "\033[35m",
        "cyan": "\033[36m",
        "white": "\033[37m",
        "bright_black": "\033[90m",  # bright colors
        "bright_red": "\033[91m",
        "bright_green": "\033[92m",
        "bright_yellow": "\033[93m",
        "bright_blue": "\033[94m",
        "bright_magenta": "\033[95m",
        "bright_cyan": "\033[96m",
        "bright_white": "\033[97m",
        "end": "\033[0m",  # misc
        "bold": "\033[1m",
        "underline": "\033[4m",
    }
    if len(input) > 1:
        *styles, text = input
    else:
        styles, text = ["blue", "bold"], input[0]
    prefix = "".join(colors[name] for name in styles)
    return prefix + f"{text}" + colors["end"]
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
    # Print function arguments (optional args dict)
    # Inspects the caller's frame so it can log that function's arguments
    # without them being passed in explicitly.
    x = inspect.currentframe().f_back  # previous frame
    file, _, func, _, _ = inspect.getframeinfo(x)
    if args is None:  # get args automatically
        args, _, _, frm = inspect.getargvalues(x)
        args = {k: v for k, v in frm.items() if k in args}
    try:
        # Show the path relative to the project root when possible.
        file = Path(file).resolve().relative_to(ROOT).with_suffix("")
    except ValueError:
        file = Path(file).stem
    s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
    logger.info(colorstr(s) + ", ".join(f"{k}={v}" for k, v in args.items()))
def increment_path(path, exist_ok=False, sep="", mkdir=False):
    """Increment a file or directory path if it already exists,
    i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.

    Args:
        path: Base path (os-agnostic).
        exist_ok (bool): If True, return ``path`` unchanged even when it exists.
        sep (str): Separator placed before the counter.
        mkdir (bool): If True, create the resulting directory.
    """
    path = Path(path)
    if path.exists() and not exist_ok:
        # For files, keep the suffix after the counter (exp.txt -> exp2.txt).
        if path.is_file():
            stem, suffix = path.with_suffix(""), path.suffix
        else:
            stem, suffix = path, ""
        for n in range(2, 9999):
            candidate = f"{stem}{sep}{n}{suffix}"
            if not os.path.exists(candidate):
                break
        path = Path(candidate)
    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory
    return path
def train(args):
    """Train GeoTr on Doc3D, validating each epoch and keeping the best
    checkpoint by average validation SSIM.

    Args:
        args: Parsed CLI namespace (data_root, img_size, epochs, batch_size,
            lr, resume, workers, save_dir, seed, use_vdl).
    """
    save_dir = Path(args.save_dir)
    use_vdl = args.use_vdl
    if use_vdl:
        from visualdl import LogWriter

        log_dir = save_dir / "vdl"
        vdl_writer = LogWriter(str(log_dir))
    # Directories
    weights_dir = save_dir / "weights"
    # NOTE(review): this creates save_dir but not weights_dir itself —
    # confirm paddle.save creates missing parent directories.
    weights_dir.parent.mkdir(parents=True, exist_ok=True)
    last = weights_dir / "last.ckpt"
    best = weights_dir / "best.ckpt"
    # Hyperparameters
    # Config
    init_seeds(args.seed)
    # Train loader
    train_dataset = Doc3dDataset(
        args.data_root,
        split="train",
        is_augment=True,
        image_size=args.img_size,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
    )
    # Validation loader
    val_dataset = Doc3dDataset(
        args.data_root,
        split="val",
        is_augment=False,
        image_size=args.img_size,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, num_workers=args.workers
    )
    # Model
    model = GeoTr()
    if use_vdl:
        vdl_writer.add_graph(
            model,
            input_spec=[
                paddle.static.InputSpec([1, 3, args.img_size, args.img_size], "float32")
            ],
        )
    # Data Parallel Mode
    if RANK == -1 and paddle.device.cuda.device_count() > 1:
        model = paddle.DataParallel(model)
    # Scheduler: one cycle over the whole run, stepped once per batch.
    scheduler = optim.lr.OneCycleLR(
        max_learning_rate=args.lr,
        total_steps=args.epochs * len(train_loader),
        phase_pct=0.1,
        end_learning_rate=args.lr / 2.5e5,
    )
    # Optimizer
    optimizer = optim.AdamW(
        learning_rate=scheduler,
        parameters=model.parameters(),
    )
    # loss function
    l1_loss_fn = nn.L1Loss()
    mse_loss_fn = nn.MSELoss()
    # Resume
    best_fitness, start_epoch = 0.0, 0
    if args.resume:
        ckpt = paddle.load(args.resume)
        model.set_state_dict(ckpt["model"])
        optimizer.set_state_dict(ckpt["optimizer"])
        scheduler.set_state_dict(ckpt["scheduler"])
        best_fitness = ckpt["best_fitness"]
        start_epoch = ckpt["epoch"] + 1
    # Train
    for epoch in range(start_epoch, args.epochs):
        model.train()
        for i, (img, target) in enumerate(train_loader):
            img = paddle.to_tensor(img)  # NCHW
            target = paddle.to_tensor(target)  # NHWC
            pred = model(img)  # NCHW
            pred_nhwc = pred.transpose([0, 2, 3, 1])
            loss = l1_loss_fn(pred_nhwc, target)
            mse_loss = mse_loss_fn(pred_nhwc, target)
            if use_vdl:
                vdl_writer.add_scalar(
                    "Train/L1 Loss", float(loss), epoch * len(train_loader) + i
                )
                vdl_writer.add_scalar(
                    "Train/MSE Loss", float(mse_loss), epoch * len(train_loader) + i
                )
                vdl_writer.add_scalar(
                    "Train/Learning Rate",
                    float(scheduler.get_lr()),
                    epoch * len(train_loader) + i,
                )
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.clear_grad()
            if i % 10 == 0:
                logger.info(
                    f"[TRAIN MODE] Epoch: {epoch}, Iter: {i}, L1 Loss: {float(loss)}, "
                    f"MSE Loss: {float(mse_loss)}, LR: {float(scheduler.get_lr())}"
                )
        # Validation
        model.eval()
        with paddle.no_grad():
            avg_ssim = paddle.zeros([])
            avg_ms_ssim = paddle.zeros([])
            avg_l1_loss = paddle.zeros([])
            avg_mse_loss = paddle.zeros([])
            for i, (img, target) in enumerate(val_loader):
                img = paddle.to_tensor(img)
                target = paddle.to_tensor(target)
                pred = model(img)
                pred_nhwc = pred.transpose([0, 2, 3, 1])
                # predict image
                out = F.grid_sample(img, (pred_nhwc / args.img_size - 0.5) * 2)
                out_gt = F.grid_sample(img, (target / args.img_size - 0.5) * 2)
                # calculate ssim
                ssim_val = ssim(out, out_gt, data_range=1.0)
                ms_ssim_val = ms_ssim(out, out_gt, data_range=1.0)
                loss = l1_loss_fn(pred_nhwc, target)
                mse_loss = mse_loss_fn(pred_nhwc, target)
                # calculate fitness
                avg_ssim += ssim_val
                avg_ms_ssim += ms_ssim_val
                avg_l1_loss += loss
                avg_mse_loss += mse_loss
                if i % 10 == 0:
                    logger.info(
                        f"[VAL MODE] Epoch: {epoch}, VAL Iter: {i}, "
                        f"L1 Loss: {float(loss)} MSE Loss: {float(mse_loss)}, "
                        f"MS-SSIM: {float(ms_ssim_val)}, SSIM: {float(ssim_val)}"
                    )
                if use_vdl and i == 0:
                    img_0 = to_image(out[0])
                    img_gt_0 = to_image(out_gt[0])
                    vdl_writer.add_image("Val/Predicted Image No.0", img_0, epoch)
                    vdl_writer.add_image("Val/Target Image No.0", img_gt_0, epoch)
                    img_1 = to_image(out[1])
                    img_gt_1 = to_image(out_gt[1])
                    # NOTE(review): only image No.1 is cast to uint8 — looks
                    # inconsistent with No.0/No.2; confirm to_image's dtype.
                    img_gt_1 = img_gt_1.astype("uint8")
                    vdl_writer.add_image("Val/Predicted Image No.1", img_1, epoch)
                    vdl_writer.add_image("Val/Target Image No.1", img_gt_1, epoch)
                    img_2 = to_image(out[2])
                    img_gt_2 = to_image(out_gt[2])
                    vdl_writer.add_image("Val/Predicted Image No.2", img_2, epoch)
                    vdl_writer.add_image("Val/Target Image No.2", img_gt_2, epoch)
            avg_ssim /= len(val_loader)
            avg_ms_ssim /= len(val_loader)
            avg_l1_loss /= len(val_loader)
            avg_mse_loss /= len(val_loader)
            if use_vdl:
                # Log epoch-level averages (bug fix: this previously logged
                # the values of the last validation batch only).
                vdl_writer.add_scalar("Val/L1 Loss", float(avg_l1_loss), epoch)
                vdl_writer.add_scalar("Val/MSE Loss", float(avg_mse_loss), epoch)
                vdl_writer.add_scalar("Val/SSIM", float(avg_ssim), epoch)
                vdl_writer.add_scalar("Val/MS-SSIM", float(avg_ms_ssim), epoch)
        # Save
        ckpt = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "best_fitness": best_fitness,
            "epoch": epoch,
        }
        paddle.save(ckpt, str(last))
        if best_fitness < avg_ssim:
            best_fitness = avg_ssim
            # Keep the stored best_fitness in sync with the new best value.
            ckpt["best_fitness"] = best_fitness
            paddle.save(ckpt, str(best))
    if use_vdl:
        vdl_writer.close()
def main(args):
    """Entry point: echo parsed args, resolve a unique run directory, start training."""
    print_args(vars(args))
    # Pick runs/<project>/<name>, auto-incrementing unless --exist-ok was given.
    run_dir = increment_path(Path(args.project) / args.name, exist_ok=args.exist_ok)
    args.save_dir = str(run_dir)
    train(args)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Hyperparams")
    # --- data / input ---
    parser.add_argument("--data-root", nargs="?", type=str,
                        default="~/datasets/doc3d",
                        help="The root path of the dataset")
    parser.add_argument("--img-size", nargs="?", type=int, default=288,
                        help="The size of the input image")
    # --- optimization ---
    parser.add_argument("--epochs", nargs="?", type=int, default=65,
                        help="The number of training epochs")
    parser.add_argument("--batch-size", nargs="?", type=int, default=12,
                        help="Batch Size")
    parser.add_argument("--lr", nargs="?", type=float, default=1e-04,
                        help="Learning Rate")
    # --- checkpointing / run management ---
    parser.add_argument("--resume", nargs="?", type=str, default=None,
                        help="Path to previous saved model to restart from")
    parser.add_argument("--workers", type=int, default=8,
                        help="max dataloader workers")
    parser.add_argument("--project", default=ROOT / "runs/train",
                        help="save to project/name")
    parser.add_argument("--name", default="exp",
                        help="save to project/name")
    parser.add_argument("--exist-ok", action="store_true",
                        help="existing project/name ok, do not increment")
    parser.add_argument("--seed", type=int, default=0,
                        help="Global training seed")
    parser.add_argument("--use-vdl", action="store_true",
                        help="use VisualDL as logger")
    main(parser.parse_args())

View File

@@ -1,10 +0,0 @@
# Allow OpenCV to read/write OpenEXR files (presumably the dataset's
# ground-truth maps are .exr — confirm against the data loader).
export OPENCV_IO_ENABLE_OPENEXR=1
# glog flag consumed by Paddle; 0 keeps logs out of stderr.
export FLAGS_logtostderr=0
# Restrict training to the first GPU.
export CUDA_VISIBLE_DEVICES=0
python train.py --img-size 288 \
    --name "DocTr++" \
    --batch-size 12 \
    --lr 1e-4 \
    --exist-ok \
    --use-vdl

View File

@@ -1,67 +0,0 @@
import numpy as np
import paddle
from paddle.nn import functional as F
def to_tensor(img: np.ndarray):
    """
    Convert an HWC uint8 image array into a normalized NCHW Paddle tensor.
    Args:
        img (numpy.ndarray): Input image, HWC layout.
    Returns:
        out (paddle.Tensor): Tensor of shape (1, C, H, W), values in [0, 1].
    """
    rgb = img[:, :, ::-1]                    # reverse channel order (BGR<->RGB)
    scaled = rgb.astype("float32") / 255.0   # [0, 255] -> [0.0, 1.0]
    chw = scaled.transpose(2, 0, 1)          # HWC -> CHW
    tensor: paddle.Tensor = paddle.to_tensor(chw)
    # Prepend the batch dimension: CHW -> NCHW.
    return paddle.unsqueeze(tensor, axis=0)
def to_image(x: paddle.Tensor):
    """
    Convert an NCHW Paddle tensor back into an HWC uint8 image array.
    Args:
        x (paddle.Tensor): Input tensor; batch/singleton dims are squeezed away.
    Returns:
        out (numpy.ndarray): HWC uint8 image with channel order reversed back.
    """
    arr: np.ndarray = x.squeeze().numpy()   # drop singleton dims -> CHW
    hwc = arr.transpose(1, 2, 0)            # CHW -> HWC
    hwc = (hwc * 255.0).astype("uint8")     # [0, 1] floats -> uint8
    # Reverse channel order (inverse of to_tensor's flip).
    return hwc[:, :, ::-1]
def unwarp(img, bm, bm_data_format="NCHW"):
    """
    Unwarp an image by resampling it through a backward flow field.
    Args:
        img (paddle.Tensor): Input image, NCHW.
        bm (paddle.Tensor): Flow field, layout given by `bm_data_format`.
        bm_data_format (str): "NCHW" (default) or "NHWC".
    Returns:
        out (paddle.Tensor): The resampled image.
    """
    _, _, height, width = img.shape
    if bm_data_format == "NHWC":
        # Normalize to NCHW so the resize below sees channels first.
        bm = bm.transpose([0, 3, 1, 2])
    # Resize the flow to the image resolution (NCHW in, NCHW out).
    bm = F.upsample(bm, size=(height, width), mode="bilinear", align_corners=True)
    # grid_sample expects the sampling grid as NHWC.
    grid = bm.transpose([0, 2, 3, 1])
    return F.grid_sample(img, grid)

View File

@@ -1,152 +0,0 @@
import math
import numpy as np
import paddle
import paddle.nn.initializer as init
from scipy import special
def weight_init_(
    layer, func, weight_name=None, bias_name=None, bias_value=0.0, **kwargs
):
    """
    In-place params init function.

    Looks up initializer class `func` by name on `paddle.nn.initializer`
    and applies it to `layer.weight`; the bias (if any) is set to the
    constant `bias_value`.

    Args:
        layer: Layer whose `weight`/`bias` attributes are initialized in place.
        func (str): Name of an initializer class in `paddle.nn.initializer`,
            e.g. 'Normal'. `**kwargs` are forwarded to its constructor.
        weight_name (str, optional): If given, overrides the weight's
            parameter name.
        bias_name (str, optional): If given, overrides the bias's
            parameter name.
        bias_value (float): Constant used to fill the bias. Default 0.0.
    Usage:
    .. code-block:: python
        import paddle
        import numpy as np
        data = np.ones([3, 4], dtype='float32')
        linear = paddle.nn.Linear(4, 4)
        input = paddle.to_tensor(data)
        print(linear.weight)
        linear(input)
        weight_init_(linear, 'Normal', 'fc_w0', 'fc_b0', std=0.01, mean=0.1)
        print(linear.weight)
    """
    if hasattr(layer, "weight") and layer.weight is not None:
        getattr(init, func)(**kwargs)(layer.weight)
        if weight_name is not None:
            # override weight name
            layer.weight.name = weight_name
    if hasattr(layer, "bias") and layer.bias is not None:
        init.Constant(bias_value)(layer.bias)
        if bias_name is not None:
            # override bias name
            layer.bias.name = bias_name
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """
    Fill `tensor` in place with values drawn from a normal distribution
    N(mean, std) truncated to the interval [a, b], using inverse-CDF
    (erfinv) sampling on NumPy, then copy the result into the tensor.

    Returns the same tensor for chaining.
    """
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
    if (mean < a - 2 * std) or (mean > b + 2 * std):
        print(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect."
        )
    with paddle.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)
        # Uniformly fill tensor with values from [l, u], then translate to [2l-1, 2u-1].
        tmp = np.random.uniform(
            2 * lower - 1, 2 * upper - 1, size=list(tensor.shape)
        ).astype(np.float32)
        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tmp = special.erfinv(tmp)
        # Transform to proper mean, std
        tmp *= std * math.sqrt(2.0)
        tmp += mean
        # Clamp to ensure it's in the proper range
        tmp = np.clip(tmp, a, b)
        tensor.set_value(paddle.to_tensor(tmp))
    return tensor
def _calculate_fan_in_and_fan_out(tensor):
    """Return (fan_in, fan_out) for a weight tensor laid out as
    (out_channels, in_channels, *kernel_dims).

    Raises:
        ValueError: if the tensor has fewer than 2 dimensions.
    """
    if tensor.dim() < 2:
        raise ValueError(
            "Fan in and fan out can not be computed for tensor "
            "with fewer than 2 dimensions"
        )
    out_channels = tensor.shape[0]
    in_channels = tensor.shape[1]
    # For conv kernels, each fan scales by the kernel's element count.
    receptive_field_size = tensor[0][0].numel() if tensor.dim() > 2 else 1
    return in_channels * receptive_field_size, out_channels * receptive_field_size
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """In-place init of `tensor` from N(mean, std) truncated to [a, b]; returns it."""
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def kaiming_normal_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
    """
    Kaiming (He) normal initialization, in place.

    Fills `tensor` from N(0, std) with std = gain / sqrt(fan), where `fan`
    is chosen by `mode` and `gain` by `nonlinearity`. Returns the tensor.

    Args:
        tensor: Weight tensor to initialize (at least 2-D).
        a: Negative slope passed as `param` to the gain computation
           (only meaningful for 'leaky_relu').
        mode (str): 'fan_in' or 'fan_out'.
        nonlinearity (str): Activation following this layer.
    """
    def _calculate_correct_fan(tensor, mode):
        # Pick fan_in or fan_out according to `mode`.
        mode = mode.lower()
        valid_modes = ["fan_in", "fan_out"]
        if mode not in valid_modes:
            raise ValueError(
                "Mode {} not supported, please use one of {}".format(mode, valid_modes)
            )
        fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
        return fan_in if mode == "fan_in" else fan_out
    def calculate_gain(nonlinearity, param=None):
        # Recommended gain value for the given nonlinearity.
        linear_fns = [
            "linear",
            "conv1d",
            "conv2d",
            "conv3d",
            "conv_transpose1d",
            "conv_transpose2d",
            "conv_transpose3d",
        ]
        if nonlinearity in linear_fns or nonlinearity == "sigmoid":
            return 1
        elif nonlinearity == "tanh":
            return 5.0 / 3
        elif nonlinearity == "relu":
            return math.sqrt(2.0)
        elif nonlinearity == "leaky_relu":
            if param is None:
                negative_slope = 0.01
            elif (
                # NOTE(review): precedence here is ((not bool and int) or float),
                # mirroring the upstream PyTorch implementation — confirm intent.
                not isinstance(param, bool)
                and isinstance(param, int)
                or isinstance(param, float)
            ):
                negative_slope = param
            else:
                raise ValueError("negative_slope {} not a valid number".format(param))
            return math.sqrt(2.0 / (1 + negative_slope**2))
        else:
            raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with paddle.no_grad():
        # Paddle's Normal initializer fills the tensor in place.
        paddle.nn.initializer.Normal(0, std)(tensor)
    return tensor

View File

@@ -1,6 +1,6 @@
x-env:
&template
image: fcb_photo_review:1.14.6
image: fcb_photo_review:1.16.0
restart: always
x-review:
@@ -31,30 +31,12 @@ x-mask:
driver: 'nvidia'
services:
det_api:
<<: *template
build:
context: .
container_name: det_api
hostname: det_api
volumes:
- ./log:/app/log
- ./model:/app/model
# command: [ 'det_api.py' ]
deploy:
resources:
reservations:
devices:
- device_ids: [ '0' ]
capabilities: [ 'gpu' ]
driver: 'nvidia'
photo_review_1:
<<: *review_template
build:
context: .
container_name: photo_review_1
hostname: photo_review_1
depends_on:
- det_api
command: [ 'photo_review.py', '--clean', 'True' ]
photo_review_2:

View File

@@ -8,6 +8,7 @@ LOG_PATHS = [
f"log/{HOSTNAME}/ucloud",
f"log/{HOSTNAME}/error",
f"log/{HOSTNAME}/qr",
f"log/{HOSTNAME}/sql",
]
for path in LOG_PATHS:
if not os.path.exists(path):
@@ -74,6 +75,16 @@ LOGGING_CONFIG = {
'backupCount': 14,
'encoding': 'utf-8',
},
'sql': {
'class': 'logging.handlers.TimedRotatingFileHandler',
'level': 'INFO',
'formatter': 'standard',
'filename': f'log/{HOSTNAME}/sql/fcb_photo_review_sql.log',
'when': 'midnight',
'interval': 1,
'backupCount': 14,
'encoding': 'utf-8',
},
},
# loggers定义了日志记录器
@@ -98,5 +109,10 @@ LOGGING_CONFIG = {
'level': 'DEBUG',
'propagate': False,
},
'sql': {
'handlers': ['console', 'sql'],
'level': 'DEBUG',
'propagate': False,
},
},
}

View File

@@ -160,6 +160,8 @@ def get_mask_layout(image, name, id_card_num):
result += _find_boxes_by_keys(ID_CARD_NUM_KEYS)
return result
except MemoryError as e:
raise e
except Exception as e:
logging.error("涂抹时出错!", exc_info=e)
return result
@@ -197,6 +199,8 @@ def mask_photo(img_url, name, id_card_num, color=(255, 255, 255)):
# 打开图片
image = image_util.read(img_url)
if image is None:
return False, image
original_image = image
is_masked, image = _mask(image, name, id_card_num, color)
if not is_masked:

View File

@@ -101,4 +101,14 @@ DISCHARGE_IE = Taskflow('information_extraction', schema=DISCHARGE_RECORD_SCHEMA
COST_IE = Taskflow('information_extraction', schema=COST_LIST_SCHEMA, model='uie-x-base', device_id=1,
task_path='model/cost_list_model', layout_analysis=LAYOUT_ANALYSIS, precision='fp16')
OCR = PaddleOCR(use_angle_cls=False, show_log=False, gpu_id=1, det_db_box_thresh=0.3)
OCR = PaddleOCR(
gpu_id=1,
use_angle_cls=False,
show_log=False,
det_db_thresh=0.1,
det_db_box_thresh=0.3,
det_limit_side_len=1248,
drop_score=0.3,
rec_model_dir='model/ocr/openatom_rec_repsvtr_ch_infer',
rec_algorithm='SVTR_LCNet',
)

View File

@@ -22,11 +22,11 @@ from photo_review import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_E
PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR, \
ADMISSION_ID, SETTLEMENT_ID, AGE, OCR, SETTLEMENT_IE, DISCHARGE_IE, COST_IE, PHHD_BATCH_SIZE, SLEEP_MINUTES, \
UPPERCASE_MEDICAL_EXPENSES, HOSPITAL_ALIAS, HOSPITAL_FILTER, DEPARTMENT_ALIAS, DEPARTMENT_FILTER
from ucloud import ufile
from ucloud import ufile, BUCKET
from util import image_util, util, html_util
from util.data_util import handle_date, handle_decimal, parse_department, handle_name, \
handle_insurance_type, handle_original_data, handle_hospital, handle_department, handle_id, handle_age, parse_money, \
parse_hospital
handle_insurance_type, handle_original_data, handle_hospital, handle_department, handle_age, parse_money, \
parse_hospital, handle_doctor, handle_text, handle_admission_id, handle_settlement_id
# 合并信息抽取结果
@@ -41,6 +41,7 @@ def ie_temp_image(ie, ocr, image):
cv2.imwrite(temp_file.name, image)
ie_result = []
ocr_pure_text = ''
try:
layout = util.get_ocr_layout(ocr, temp_file.name)
if not layout:
@@ -48,6 +49,11 @@ def ie_temp_image(ie, ocr, image):
ie_result = []
else:
ie_result = ie({"doc": temp_file.name, "layout": layout})[0]
for lay in layout:
ocr_pure_text += lay[1]
except MemoryError as e:
# 显存不足时应该抛出错误,让程序重启,同时释放显存
raise e
except Exception as e:
logging.error("信息抽取时出错", exc_info=e)
finally:
@@ -55,7 +61,7 @@ def ie_temp_image(ie, ocr, image):
os.remove(temp_file.name)
except Exception as e:
logging.info(f"删除临时文件 {temp_file.name} 时出错", exc_info=e)
return ie_result
return ie_result, ocr_pure_text
# 关键信息提取
@@ -147,12 +153,16 @@ def get_better_image_from_qrcode(image, image_id, dpi=150):
# 关键信息提取
def information_extraction(ie, phrecs, identity):
result = {}
ocr_text = ''
for phrec in phrecs:
img_path = ufile.get_private_url(phrec.cfjaddress)
if not img_path:
continue
image = image_util.read(img_path)
if image is None:
# 图片可能因为某些原因获取不到
continue
# 尝试从二维码中获取高清图片
better_image, text = get_better_image_from_qrcode(image, phrec.cfjaddress)
@@ -165,7 +175,7 @@ def information_extraction(ie, phrecs, identity):
if text:
info_extract = ie(text)[0]
else:
info_extract = ie_temp_image(ie, OCR, image)
info_extract = ie_temp_image(ie, OCR, image)[0]
ie_result = {'result': info_extract, 'angle': '0'}
now = util.get_default_datetime()
@@ -198,10 +208,12 @@ def information_extraction(ie, phrecs, identity):
if split_result['img'] is None or split_result['img'].size == 0:
continue
rotated_img = image_util.rotate(split_result['img'], int(angles[0]))
ie_results = [{'result': ie_temp_image(ie, OCR, rotated_img), 'angle': angles[0]}]
ie_temp_result = ie_temp_image(ie, OCR, rotated_img)
ocr_text += ie_temp_result[1]
ie_results = [{'result': ie_temp_result[0], 'angle': angles[0]}]
if not ie_results[0]['result'] or len(ie_results[0]['result']) < len(ie.kwargs.get('schema')):
rotated_img = image_util.rotate(split_result['img'], int(angles[1]))
ie_results.append({'result': ie_temp_image(ie, OCR, rotated_img), 'angle': angles[1]})
ie_results.append({'result': ie_temp_image(ie, OCR, rotated_img)[0], 'angle': angles[1]})
now = util.get_default_datetime()
best_angle = ['0', 0]
for ie_result in ie_results:
@@ -231,13 +243,15 @@ def information_extraction(ie, phrecs, identity):
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
cv2.imwrite(temp_file.name, image)
try:
ufile.upload_file(phrec.cfjaddress, temp_file.name)
if img_angle != '0':
ufile.upload_file(phrec.cfjaddress, temp_file.name)
logging.info(f'旋转图片[{phrec.cfjaddress}]替换成功,已旋转{img_angle}度。')
# 修正旋转角度
for zx_ie_result in zx_ie_results:
zx_ie_result.rotation_angle -= int(img_angle)
else:
ufile.copy_file(BUCKET, phrec.cfjaddress, "drg2015", phrec.cfjaddress)
ufile.upload_file(phrec.cfjaddress, temp_file.name)
logging.info(f'高清图片[{phrec.cfjaddress}]替换成功!')
except Exception as e:
logging.error(f'上传图片({phrec.cfjaddress})失败', exc_info=e)
@@ -247,8 +261,21 @@ def information_extraction(ie, phrecs, identity):
session = MysqlSession()
session.add_all(zx_ie_results)
session.commit()
session.close()
# # 添加清晰度测试
# if better_image is None:
# # 替换后图片默认清晰
# clarity_result = image_util.parse_clarity(image)
# unsharp_flag = 0 if (clarity_result[0] == 0 and clarity_result[1] >= 0.8) else 1
# update_clarity = (update(ZxPhrec).where(ZxPhrec.pk_phrec == phrec.pk_phrec).values(
# cfjaddress2=json.dumps(clarity_result),
# unsharp_flag=unsharp_flag,
# ))
# session.execute(update_clarity)
# session.commit()
# session.close()
result['ocr_text'] = ocr_text
return result
@@ -293,7 +320,7 @@ def save_or_update_ie(table, pk_phhd, data):
if db_data:
# 更新
db_data.update_time = now
db_data.creator = HOSTNAME
db_data.updater = HOSTNAME
for k, v in data.items():
setattr(db_data, k, v)
else:
@@ -379,8 +406,8 @@ def settlement_task(pk_phhd, settlement_list, identity):
get_best_value_in_keys(settlement_list_ie_result, PERSONAL_FUNDED_AMOUNT)),
"medical_insurance_type_str": handle_original_data(
get_best_value_in_keys(settlement_list_ie_result, MEDICAL_INSURANCE_TYPE)),
"admission_id": handle_id(get_best_value_in_keys(settlement_list_ie_result, ADMISSION_ID)),
"settlement_id": handle_id(get_best_value_in_keys(settlement_list_ie_result, SETTLEMENT_ID)),
"admission_id": handle_admission_id(get_best_value_in_keys(settlement_list_ie_result, ADMISSION_ID)),
"settlement_id": handle_settlement_id(get_best_value_in_keys(settlement_list_ie_result, SETTLEMENT_ID)),
}
settlement_data["admission_date"] = handle_date(settlement_data["admission_date_str"])
settlement_data["admission_date"] = handle_date(settlement_data["admission_date_str"])
@@ -394,6 +421,10 @@ def settlement_task(pk_phhd, settlement_list, identity):
get_best_value_in_keys(settlement_list_ie_result, MEDICAL_EXPENSES))
settlement_data["medical_expenses_str"] = handle_original_data(parse_money_result[0])
settlement_data["medical_expenses"] = parse_money_result[1]
if not settlement_data["settlement_id"]:
# 如果没有结算单号就填住院号
settlement_data["settlement_id"] = settlement_data["admission_id"]
save_or_update_ie(ZxIeSettlement, pk_phhd, settlement_data)
@@ -408,9 +439,10 @@ def discharge_task(pk_phhd, discharge_record, identity):
"name": handle_name(get_best_value_in_keys(discharge_record_ie_result, PATIENT_NAME)),
"admission_date_str": handle_original_data(get_best_value_in_keys(discharge_record_ie_result, ADMISSION_DATE)),
"discharge_date_str": handle_original_data(get_best_value_in_keys(discharge_record_ie_result, DISCHARGE_DATE)),
"doctor": handle_name(get_best_value_in_keys(discharge_record_ie_result, DOCTOR)),
"admission_id": handle_id(get_best_value_in_keys(discharge_record_ie_result, ADMISSION_ID)),
"doctor": handle_doctor(get_best_value_in_keys(discharge_record_ie_result, DOCTOR)),
"admission_id": handle_admission_id(get_best_value_in_keys(discharge_record_ie_result, ADMISSION_ID)),
"age": handle_age(get_best_value_in_keys(discharge_record_ie_result, AGE)),
"content": handle_text(discharge_record_ie_result['ocr_text']),
}
discharge_data["admission_date"] = handle_date(discharge_data["admission_date_str"])
discharge_data["discharge_date"] = handle_date(discharge_data["discharge_date_str"])
@@ -476,7 +508,8 @@ def cost_task(pk_phhd, cost_list, identity):
"name": handle_name(get_best_value_in_keys(cost_list_ie_result, PATIENT_NAME)),
"admission_date_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, ADMISSION_DATE)),
"discharge_date_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, DISCHARGE_DATE)),
"medical_expenses_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, MEDICAL_EXPENSES))
"medical_expenses_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, MEDICAL_EXPENSES)),
"content": handle_text(cost_list_ie_result['ocr_text']),
}
cost_data["admission_date"] = handle_date(cost_data["admission_date_str"])
cost_data["discharge_date"] = handle_date(cost_data["discharge_date_str"])

25
tool/batch_download.py Normal file
View File

@@ -0,0 +1,25 @@
# 批量下载图片
import pandas as pd
import requests
from ucloud import ufile
if __name__ == '__main__':
    # Read the spreadsheet listing the images to download.
    file_path = r'D:\Echo\Downloads\Untitled.xlsx'  # replace with your xlsx path
    df = pd.read_excel(file_path)
    # Iterate over the cfjaddress column and fetch each image.
    for address in df['cfjaddress']:
        img_url = ufile.get_private_url(address)
        try:
            # A bounded timeout keeps one stuck download from hanging the
            # whole batch (requests.get blocks forever without one).
            response = requests.get(img_url, timeout=30)
        except requests.RequestException as e:
            # One failed download should not abort the remaining ones.
            print(f"{address}下载失败,状态码: {e}")
            continue
        # Only persist successful responses.
        if response.status_code == 200:
            # Save the image next to the script, under ../img/.
            with open(f'../img/{address}', 'wb') as file:
                file.write(response.content)
            print(f"{address}下载成功")
        else:
            print(f"{address}下载失败,状态码: {response.status_code}")

View File

@@ -0,0 +1,43 @@
import pandas as pd
from db import MysqlSession
from db.mysql import ZxPhrec
def check_unclarity():
    """For every pk_phrec listed in the spreadsheet, write out the rows the
    database still marks as sharp (unsharp_flag == 0) with their clarity data.

    Output: one "pk_phrec cfjaddress2" line per match in
    check_unclarity_result.txt.
    """
    file_path = r'D:\Echo\Downloads\unclare.xlsx'  # replace with your xlsx path
    df = pd.read_excel(file_path)
    output_file_path = 'check_unclarity_result.txt'
    # One session for the whole scan — the original opened (and leaked on
    # error) a brand-new MysqlSession for every spreadsheet row.
    session = MysqlSession()
    try:
        with open(output_file_path, 'w') as f:
            for pk_phrec in df['pk_phrec']:
                phrec = session.query(
                    ZxPhrec.pk_phrec, ZxPhrec.cfjaddress2, ZxPhrec.unsharp_flag
                ).filter(
                    ZxPhrec.pk_phrec == pk_phrec
                ).one_or_none()
                if phrec and phrec.unsharp_flag == 0:
                    f.write(f"{phrec.pk_phrec} {phrec.cfjaddress2}\n")
    finally:
        # Close even if a query raises, so the connection is not leaked.
        session.close()
def check_clarity():
    """Write out DB rows flagged unsharp (unsharp_flag == 1) within a fixed
    pk_phrec range that are NOT present in the reference spreadsheet.

    Output: one "pk_phrec cfjaddress2" line per row in
    check_clarity_result.txt.
    """
    file_path = r'D:\Echo\Downloads\unclare.xlsx'
    df = pd.read_excel(file_path)
    # Build the membership set ONCE. The original called
    # df['pk_phrec'].to_list() inside the loop, rebuilding an O(m) list for
    # every row — accidental O(n*m); a set lookup is O(1).
    known_ids = set(df['pk_phrec'].to_list())
    session = MysqlSession()
    phrecs = (session.query(ZxPhrec.pk_phrec, ZxPhrec.cfjaddress2, ZxPhrec.unsharp_flag)
              .filter(ZxPhrec.pk_phrec >= 30810206)
              .filter(ZxPhrec.pk_phrec <= 31782247)
              .filter(ZxPhrec.unsharp_flag == 1)
              .all())
    session.close()
    output_file_path = 'check_clarity_result.txt'
    with open(output_file_path, 'w') as f:
        for phrec in phrecs:
            if phrec.pk_phrec not in known_ids:
                f.write(f"{phrec.pk_phrec} {phrec.cfjaddress2}\n")
if __name__ == '__main__':
    check_clarity()

17
tool/paddle2nb.py Normal file
View File

@@ -0,0 +1,17 @@
# paddle模型转nb格式推荐使用paddlelite2.10
from paddlelite.lite import Opt
if __name__ == '__main__':
    # 1. Create the Opt instance.
    opt = Opt()
    # 2. Point it at the input Paddle model directory.
    opt.set_model_dir("./model")
    # 3. Choose the target backend: arm, x86, opencl, npu
    #    (arm is generally the CPU target, opencl the GPU target).
    opt.set_valid_places("arm")
    # 4. Choose the serialized model format: naive_buffer or protobuf.
    opt.set_model_type("naive_buffer")
    # 5. Output path (without extension) for the converted model.
    opt.set_optimize_out("model")
    # 6. Run the conversion/optimization.
    opt.run()

View File

@@ -2,7 +2,7 @@ import logging
import re
from datetime import datetime
from util import util
from util import util, string_util
# 处理金额类数据
@@ -12,6 +12,7 @@ def handle_decimal(string):
string = re.sub(r'[^0-9.]', '', string)
if not string:
return ""
string = string_util.full_to_half(string)
if "." not in string:
if len(string) > 2:
result = string[:-2] + "." + string[-2:]
@@ -89,13 +90,17 @@ def handle_date(string):
def handle_hospital(string):
if not string:
return ""
return string[:255]
# 只允许汉字、数字
string = re.sub(r'[^⺀-鿿0-9]', '', string)
return string[:200]
def handle_department(string):
if not string:
return ""
return string[:255]
# 只允许汉字
string = re.sub(r'[^⺀-鿿]', '', string)
return string[:200]
def parse_department(string):
@@ -119,7 +124,13 @@ def parse_department(string):
def handle_name(string):
if not string:
return ""
return re.sub(r'[^⺀-鿿·]', '', string)[:30]
return re.sub(r'[^⺀-鿿·]', '', string)[:20]
def handle_doctor(string):
    """Sanitize a doctor-name field: keep only CJK characters and the
    name middle-dot, capped at 20 characters; falsy input yields ""."""
    if not string:
        return ""
    cleaned = re.sub(r'[^⺀-鿿·]', '', string)
    return cleaned[:20]
# 处理医保类型数据
@@ -150,18 +161,29 @@ def handle_original_data(string):
# 处理id类数据
def handle_id(string):
def handle_admission_id(string):
if not string:
return ""
# 只允许字母和数字
string = re.sub(r'[^0-9a-zA-Z]', '', string)
# 防止过长存入数据库失败
return string[:50]
return string[:20]
def handle_settlement_id(string):
    """Sanitize a settlement id: keep ASCII letters/digits only and cap at
    30 characters so the value fits its database column."""
    if not string:
        return ""
    cleaned = re.sub(r'[^0-9a-zA-Z]', '', string)
    return cleaned[:30]
# 处理年龄类数据
def handle_age(string):
if not string:
return ""
string = string.split("")[0]
string = string_util.full_to_half(string.split("")[0])
num = re.sub(r'\D', '', string)
return num[-3:]
@@ -178,3 +200,9 @@ def parse_hospital(string):
split_hospitals = string_without_company.replace("医院", "医院 ")
result += split_hospitals.strip().split(" ")
return result
def handle_text(string):
    """Truncate free text to 16383 characters (DB column limit); falsy -> ""."""
    return string[:16383] if string else ""

View File

@@ -72,14 +72,18 @@ def split(image, ratio=1.414, overlap=0.05, x_compensation=3):
for i in range(math.ceil(height / step)):
offset = round(step * i)
cropped_img = capture(image, [0, offset, width, offset + new_img_height])
split_result.append({"img": cropped_img, "x_offset": 0, "y_offset": offset})
if cropped_img.shape[0] > 0:
# 计算误差可能导致图片高度为0此时不添加
split_result.append({"img": cropped_img, "x_offset": 0, "y_offset": offset})
elif wh_ratio > ratio: # 横向过长
new_img_width = height * ratio
step = height * (ratio - overlap * x_compensation) # 一般文字是横向的,所以横向截取时增大重叠部分
for i in range(math.ceil(width / step)):
offset = round(step * i)
cropped_img = capture(image, [offset, 0, offset + new_img_width, width])
split_result.append({"img": cropped_img, "x_offset": offset, "y_offset": 0})
if cropped_img.shape[1] > 0:
# 计算误差可能导致图片宽度为0此时不添加
split_result.append({"img": cropped_img, "x_offset": offset, "y_offset": 0})
else:
split_result.append({"img": image, "x_offset": 0, "y_offset": 0})
return split_result
@@ -247,3 +251,78 @@ def combined(img1, img2):
combined_img[:height1, :width1] = img1
combined_img[:height2, width1:width1 + width2] = img2
return combined_img
def parse_clarity(image):
    """
    Classify image sharpness with the clarity-assessment PaddleClas model.
    :param image: image as a NumPy array or a file path
    :return: [class_id, score]; defaults to [1, 0] when prediction fails
    """
    # Fallback result used when prediction fails (class 1, zero confidence).
    clarity_result = [1, 0]
    # NOTE(review): the model is re-created on every call — consider loading
    # it once at module level if this runs in a hot path.
    model = PaddleClas(inference_model_dir=r"model/clas/clarity_assessment", use_gpu=True)
    clas_result = model.predict(input_data=image)
    try:
        # predict() returns a generator; take the first prediction.
        clas_result = next(clas_result)[0]
        clarity_result = [clas_result["class_ids"][0], clas_result["scores"][0]]
    except Exception as e:
        logging.error("获取图片清晰度失败", exc_info=e)
    return clarity_result
def is_photo_by_exif(exif_tags):
    """Heuristic photo detector based on EXIF metadata.

    Returns True when at least two camera-typical EXIF tags are present;
    otherwise (including empty/missing EXIF) returns False.
    """
    # Tags that cameras typically write but screenshots never carry.
    camera_tags = (
        'FNumber',          # aperture
        'ExposureTime',     # shutter speed
        'ISOSpeedRatings',  # ISO
        'FocalLength',      # focal length
        'LensModel',        # lens model
        'GPSLatitude',      # GPS position
    )
    if not exif_tags:
        # No EXIF at all: cannot confirm it is a photo.
        return False
    hits = sum(1 for tag in camera_tags if tag in exif_tags)
    # Two or more camera tags is strong evidence of a real photo.
    return hits >= 2
def is_screenshot_by_image_features(image):
    """Screenshot heuristic from pixel statistics: screenshots tend to have
    uniform borders, so a low standard deviation over the 10-pixel frame
    suggests a screenshot. Returns False if analysis fails."""
    # Below this std over the border pixels we call the image a screenshot.
    edge_std_threshold = 20.0
    try:
        # Gather a 10-pixel-wide frame around the image.
        border = numpy.concatenate([
            image[:10, :].ravel(),    # top strip
            image[-10:, :].ravel(),   # bottom strip
            image[:, :10].ravel(),    # left strip
            image[:, -10:].ravel(),   # right strip
        ])
        # Lower std means the border is more uniform.
        edge_std = numpy.std(border)
        logging.info(f"边缘像素标准差: {edge_std}")
        return edge_std < edge_std_threshold
    except Exception as e:
        logging.error("图像特征分析失败", exc_info=e)
        return False
def is_screenshot(image, exif_tags):
    """Combined screenshot detector: EXIF evidence first, pixel features second."""
    # Clear camera EXIF metadata means this is a photo, not a screenshot.
    if is_photo_by_exif(exif_tags):
        return False
    # Otherwise decide from the image's border-pixel statistics.
    return is_screenshot_by_image_features(image)

View File

@@ -1,3 +1,6 @@
import unicodedata
def blank(string):
"""
判断字符串是否为空或者纯空格
@@ -5,3 +8,23 @@ def blank(string):
:return: 字符串是否为空或者纯空格
"""
return not string or string.isspace()
def full_to_half(string):
    """Convert full-width characters to their half-width equivalents.

    Only characters whose East Asian width is Fullwidth ('F') or Wide ('W')
    are NFKC-normalized; everything else is passed through untouched, so
    other compatibility characters are not accidentally folded.

    :param string: input string
    :return: converted string
    :raises TypeError: if the input is not a str
    """
    if not isinstance(string, str):
        raise TypeError("全角转半角的输入必须是字符串类型")
    if not string:
        return string
    converted = []
    for char in string:
        if unicodedata.east_asian_width(char) in ('F', 'W'):
            converted.append(unicodedata.normalize('NFKC', char))
        else:
            converted.append(char)
    return ''.join(converted)