Compare commits
61 Commits
| SHA1 |
|---|
| 97903b2722 |
| 670172e79e |
| d266c2828c |
| bc4c95c18c |
| af08078380 |
| 8984948107 |
| 843511b6f3 |
| a5d7da6536 |
| 99d555aba9 |
| 7ca5b9d908 |
| 5e5e35fd9f |
| d080b66ebf |
| 1e8ef432df |
| a6515e971b |
| e3fd3f618f |
| b6ae36a8ec |
| a99a615e22 |
| 5f645b5b4b |
| 1625f0294f |
| f19f8cbcae |
| ba3e23d185 |
| 0abf7abb5b |
| 47ac6aadbe |
| 09ede1af25 |
| e40d963bf5 |
| 34344edd29 |
| b387db1e08 |
| 88ca27928f |
| 109a5e9444 |
| ab5f78cc7b |
| 04358ee646 |
| a67c53f470 |
| cd604bc1eb |
| 0de9fc14b5 |
| 5287df4959 |
| 3e9c0c99b9 |
| a740f16e6b |
| b9606771cf |
| 110bc57abc |
| f965bc4289 |
| 8b6bf03d76 |
| 73536aea89 |
| 12f6554d8c |
| be94bc7f09 |
| 8307fcd549 |
| ab2dbf7c15 |
| dac7a1b5ce |
| 935fa26067 |
| 16ab4c78d5 |
| ed77b8ed82 |
| f4ec3b1eb4 |
| a93f68413c |
| 0a947274dc |
| f7540c4574 |
| 5e6a471954 |
| 96b8a06e6c |
| be27f753ba |
| 8ea5420520 |
| 9749577c5a |
| e0f6b82dad |
| 8223759fdf |
@@ -15,6 +15,7 @@ ENV PYTHONUNBUFFERED=1 \
COPY requirements.txt /app/requirements.txt
COPY packages /app/packages
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo '$TZ' > /etc/timezone \
    && python3 -m pip install --upgrade pip \
    && pip install --no-cache-dir -r requirements.txt \
    && pip uninstall -y onnxruntime onnxruntime-gpu \
    && pip install onnxruntime-gpu==1.18.0 --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/

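The hunk above swaps the stock onnxruntime wheel for onnxruntime-gpu 1.18.0 from the CUDA 12 package feed. A quick sanity check that the GPU provider is actually usable inside the built image could look like this (a minimal sketch, not part of the diff):

```python
# Minimal sanity check for the onnxruntime-gpu install above (illustrative only).
import onnxruntime as ort

providers = ort.get_available_providers()
print(providers)  # "CUDAExecutionProvider" should be listed on a CUDA 12 host
assert "CUDAExecutionProvider" in providers, "GPU build not active; check the CUDA/cuDNN setup"
```
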
33  Dockerfile.dev  Normal file
@@ -0,0 +1,33 @@
# Use the official Paddle image as the base
FROM ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlex/paddlex:paddlex3.1.2-paddlepaddle3.0.0-gpu-cuda12.6-cudnn9.5-trt10.5

# Set the working directory
WORKDIR /app

# Set environment variables
ENV PYTHONUNBUFFERED=1 \
    # Set the time zone
    TZ=Asia/Shanghai \
    # Set the pip mirror to speed up installs
    PIP_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple

# Install language-pack-en and openssh-server
RUN apt update && \
    apt install -y language-pack-en && \
    apt install -y openssh-server

# Configure the SSH service
RUN mkdir /var/run/sshd && \
    # Set the root password; change it as needed
    echo 'root:fcb0102' | chpasswd && \
    # Allow root login over SSH
    sed -i 's/#PermitRootLogin prohibit-password/PermitRootLogin yes/' /etc/ssh/sshd_config

# Copy the current directory into /app in the container
COPY . /app

# Expose port 22
EXPOSE 22

# Start the SSH service
CMD ["/usr/sbin/sshd", "-D"]

@@ -126,3 +126,7 @@ bash update.sh
    2. Added a dewarping (distortion correction) feature
21. Version number: 1.14.0
    1. Added QR-code recognition for swapping in high-definition images
22. Version number: 1.15.0
    1. Added an image sharpness test
23. Version number: 1.16.0
    1. Updated the paddle framework to 3.0

@@ -14,7 +14,7 @@ from ucloud import ufile
|
||||
from util import image_util
|
||||
|
||||
|
||||
def check_ie_result(pk_phhd):
|
||||
def check_ie_result(pk_phhd, need_to_annotation=True):
|
||||
os.makedirs(f"./check_result/{pk_phhd}", exist_ok=True)
|
||||
json_result = {"pk_phhd": pk_phhd}
|
||||
session = MysqlSession()
|
||||
@@ -46,45 +46,51 @@ def check_ie_result(pk_phhd):
|
||||
ZxPhrec.pk_phhd == pk_phhd).all()
|
||||
for phrec in phrecs:
|
||||
img_name = phrec.cfjaddress
|
||||
img_path = ufile.get_private_url(img_name)
|
||||
img_path = ufile.get_private_url(img_name, "drg2015")
|
||||
if not img_path:
|
||||
img_path = ufile.get_private_url(img_name)
|
||||
|
||||
response = requests.get(img_path)
|
||||
image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
font_size = image.width * image.height / 200000
|
||||
font = ImageFont.truetype("./font/simfang.ttf", size=font_size)
|
||||
if need_to_annotation:
|
||||
font_size = image.width * image.height / 200000
|
||||
font = ImageFont.truetype("./font/simfang.ttf", size=font_size)
|
||||
|
||||
ocr = session.query(ZxIeResult.id, ZxIeResult.content, ZxIeResult.rotation_angle, ZxIeResult.x_offset,
|
||||
ZxIeResult.y_offset).filter(ZxIeResult.pk_phrec == phrec.pk_phrec).all()
|
||||
if not ocr:
|
||||
ocr = session.query(ZxIeResult.id, ZxIeResult.content, ZxIeResult.rotation_angle, ZxIeResult.x_offset,
|
||||
ZxIeResult.y_offset).filter(ZxIeResult.pk_phrec == phrec.pk_phrec).all()
|
||||
if not ocr:
|
||||
os.makedirs(f"./check_result/{pk_phhd}/0", exist_ok=True)
|
||||
image.save(f"./check_result/{pk_phhd}/0/{img_name}")
|
||||
|
||||
for _, group_results in groupby(ocr, key=lambda x: x.id):
|
||||
draw = ImageDraw.Draw(image)
|
||||
for ocr_item in group_results:
|
||||
result = json.loads(ocr_item.content)
|
||||
rotation_angle = ocr_item.rotation_angle
|
||||
x_offset = ocr_item.x_offset
|
||||
y_offset = ocr_item.y_offset
|
||||
for key in result:
|
||||
for value in result[key]:
|
||||
box = value["bbox"][0]
|
||||
|
||||
if rotation_angle:
|
||||
box = image_util.invert_rotate_rectangle(box, (image.width / 2, image.height / 2),
|
||||
rotation_angle)
|
||||
if x_offset:
|
||||
box[0] += x_offset
|
||||
box[2] += x_offset
|
||||
if y_offset:
|
||||
box[1] += y_offset
|
||||
box[3] += y_offset
|
||||
|
||||
draw.rectangle(box, outline="red", width=2)  # draw the rectangle
|
||||
draw.text((box[0], box[1] - font_size), key, fill="blue", font=font)  # draw the text above the rectangle
|
||||
draw.text((box[0], box[3]), value["text"], fill="blue", font=font)  # draw the text below the rectangle
|
||||
os.makedirs(f"./check_result/{pk_phhd}/{ocr_item.id}", exist_ok=True)
|
||||
image.save(f"./check_result/{pk_phhd}/{ocr_item.id}/{img_name}")
|
||||
else:
|
||||
os.makedirs(f"./check_result/{pk_phhd}/0", exist_ok=True)
|
||||
image.save(f"./check_result/{pk_phhd}/0/{img_name}")
|
||||
|
||||
for _, group_results in groupby(ocr, key=lambda x: x.id):
|
||||
draw = ImageDraw.Draw(image)
|
||||
for ocr_item in group_results:
|
||||
result = json.loads(ocr_item.content)
|
||||
rotation_angle = ocr_item.rotation_angle
|
||||
x_offset = ocr_item.x_offset
|
||||
y_offset = ocr_item.y_offset
|
||||
for key in result:
|
||||
for value in result[key]:
|
||||
box = value["bbox"][0]
|
||||
|
||||
if rotation_angle:
|
||||
box = image_util.invert_rotate_rectangle(box, (image.width / 2, image.height / 2),
|
||||
rotation_angle)
|
||||
if x_offset:
|
||||
box[0] += x_offset
|
||||
box[2] += x_offset
|
||||
if y_offset:
|
||||
box[1] += y_offset
|
||||
box[3] += y_offset
|
||||
|
||||
draw.rectangle(box, outline="red", width=2)  # draw the rectangle
|
||||
draw.text((box[0], box[1] - font_size), key, fill="blue", font=font)  # draw the text above the rectangle
|
||||
draw.text((box[0], box[3]), value["text"], fill="blue", font=font)  # draw the text below the rectangle
|
||||
os.makedirs(f"./check_result/{pk_phhd}/{ocr_item.id}", exist_ok=True)
|
||||
image.save(f"./check_result/{pk_phhd}/{ocr_item.id}/{img_name}")
|
||||
session.close()
|
||||
|
||||
# Custom JSON handler
|
||||
@@ -99,4 +105,4 @@ def check_ie_result(pk_phhd):
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
check_ie_result(0)
|
||||
check_ie_result(5640504)
|
||||
|
||||
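The change above adds a need_to_annotation flag (default True) to check_ie_result; when it is False the images are saved without drawing boxes or labels. A hypothetical call for both modes (the module name and pk_phhd value are only examples):

```python
# Hypothetical usage of the new flag (module name and pk_phhd value are assumptions).
from check_ie_result import check_ie_result  # assumed module name for the function above

check_ie_result(5640504)                            # draw boxes and text as before
check_ie_result(5640504, need_to_annotation=False)  # save the images without any annotation
```
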
@@ -19,5 +19,5 @@ DB_URL = f'mysql+pymysql://{USERNAME}:{PASSWORD}@{HOSTNAME}:{PORT}/{DATABASE}'
SHOW_SQL = False

Engine = create_engine(DB_URL, echo=SHOW_SQL)
Base = declarative_base(Engine)
Base = declarative_base()
MysqlSession = sessionmaker(bind=Engine)

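The hunk drops the engine argument from declarative_base(), which newer SQLAlchemy releases no longer accept; the engine is now bound only through sessionmaker. A minimal, self-contained sketch of the same pattern (the SQLite URL and Demo model are illustrative, not part of the repo):

```python
# Minimal sketch of the unbound-Base pattern used above (assumes SQLAlchemy 1.4+/2.x;
# the SQLite URL and Demo model are illustrative stand-ins).
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

engine = create_engine("sqlite:///:memory:", echo=False)  # stand-in for the MySQL DB_URL
Base = declarative_base()                                 # unbound; no engine argument anymore


class Demo(Base):
    __tablename__ = "demo"
    id = Column(Integer, primary_key=True)


Base.metadata.create_all(engine)     # tables are created against an explicit engine
Session = sessionmaker(bind=engine)  # the engine is bound here, as with MysqlSession above
with Session() as session:
    session.add(Demo())
    session.commit()
```
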
27  db/mysql.py
@@ -1,5 +1,5 @@
|
||||
# coding: utf-8
|
||||
from sqlalchemy import Column, DECIMAL, Date, DateTime, Index, String, text, LargeBinary
|
||||
from sqlalchemy import Column, DECIMAL, Date, DateTime, Index, String, text, LargeBinary, Text
|
||||
from sqlalchemy.dialects.mysql import BIT, CHAR, INTEGER, TINYINT, VARCHAR
|
||||
|
||||
from db import Base
|
||||
@@ -56,13 +56,16 @@ class ZxIeCost(Base):
|
||||
|
||||
pk_ie_cost = Column(INTEGER(11), primary_key=True, comment='费用明细信息抽取主键')
|
||||
pk_phhd = Column(INTEGER(11), nullable=False, unique=True, comment='报销案子主键')
|
||||
name = Column(String(30), comment='患者姓名')
|
||||
content = Column(Text, comment='详细内容')
|
||||
name = Column(String(20), comment='患者姓名')
|
||||
admission_date_str = Column(String(255), comment='入院日期字符串')
|
||||
admission_date = Column(Date, comment='入院日期')
|
||||
discharge_date_str = Column(String(255), comment='出院日期字符串')
|
||||
discharge_date = Column(Date, comment='出院日期')
|
||||
medical_expenses_str = Column(String(255), comment='费用总额字符串')
|
||||
medical_expenses = Column(DECIMAL(18, 2), comment='费用总额')
|
||||
page_nums = Column(String(255), comment='页码')
|
||||
page_count = Column(TINYINT(4), comment='页数')
|
||||
create_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
|
||||
creator = Column(String(255), comment='创建人')
|
||||
update_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
|
||||
@@ -94,19 +97,19 @@ class ZxIeDischarge(Base):
|
||||
|
||||
pk_ie_discharge = Column(INTEGER(11), primary_key=True, comment='出院记录信息抽取主键')
|
||||
pk_phhd = Column(INTEGER(11), nullable=False, unique=True, comment='报销案子主键')
|
||||
content = Column(String(5000), comment='详细内容')
|
||||
hospital = Column(String(255), comment='医院')
|
||||
content = Column(Text, comment='详细内容')
|
||||
hospital = Column(String(200), comment='医院')
|
||||
pk_yljg = Column(INTEGER(11), comment='医院主键')
|
||||
department = Column(String(255), comment='科室')
|
||||
department = Column(String(200), comment='科室')
|
||||
pk_ylks = Column(INTEGER(11), comment='科室主键')
|
||||
name = Column(String(30), comment='患者姓名')
|
||||
name = Column(String(20), comment='患者姓名')
|
||||
age = Column(INTEGER(3), comment='年龄')
|
||||
admission_date_str = Column(String(255), comment='入院日期字符串')
|
||||
admission_date = Column(Date, comment='入院日期')
|
||||
discharge_date_str = Column(String(255), comment='出院日期字符串')
|
||||
discharge_date = Column(Date, comment='出院日期')
|
||||
doctor = Column(String(30), comment='主治医生')
|
||||
admission_id = Column(String(50), comment='住院号')
|
||||
doctor = Column(String(20), comment='主治医生')
|
||||
admission_id = Column(String(20), comment='住院号')
|
||||
create_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
|
||||
creator = Column(String(255), comment='创建人')
|
||||
update_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
|
||||
@@ -138,7 +141,7 @@ class ZxIeSettlement(Base):
|
||||
|
||||
pk_ie_settlement = Column(INTEGER(11), primary_key=True, comment='结算清单信息抽取主键')
|
||||
pk_phhd = Column(INTEGER(11), nullable=False, unique=True, comment='报销案子主键')
|
||||
name = Column(String(30), comment='患者姓名')
|
||||
name = Column(String(20), comment='患者姓名')
|
||||
admission_date_str = Column(String(255), comment='入院日期字符串')
|
||||
admission_date = Column(Date, comment='入院日期')
|
||||
discharge_date_str = Column(String(255), comment='出院日期字符串')
|
||||
@@ -152,9 +155,9 @@ class ZxIeSettlement(Base):
|
||||
personal_funded_amount_str = Column(String(255), comment='自费金额字符串')
|
||||
personal_funded_amount = Column(DECIMAL(18, 2), comment='自费金额')
|
||||
medical_insurance_type_str = Column(String(255), comment='医保类型字符串')
|
||||
medical_insurance_type = Column(String(40), comment='医保类型')
|
||||
admission_id = Column(String(50), comment='住院号')
|
||||
settlement_id = Column(String(50), comment='医保结算单号码')
|
||||
medical_insurance_type = Column(String(10), comment='医保类型')
|
||||
admission_id = Column(String(20), comment='住院号')
|
||||
settlement_id = Column(String(30), comment='医保结算单号码')
|
||||
create_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
|
||||
creator = Column(String(255), comment='创建人')
|
||||
update_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
|
||||
|
||||
74  delete_deprecated_data.py  Normal file
@@ -0,0 +1,74 @@
# Delete expired data from the local database
import logging.config
from datetime import datetime, timedelta

from db import MysqlSession
from db.mysql import ZxPhhd, ZxIeCost, ZxIeDischarge, ZxIeResult, ZxIeSettlement, ZxPhrec
from log import LOGGING_CONFIG

# Expiration window in days (not recommended to be less than one month)
EXPIRATION_DAYS = 183
# Batch delete size (best kept between 1000 and 10000)
BATCH_SIZE = 5000
# Database session object
session = None


def batch_delete_by_pk_phhd(model, pk_phhds):
    """
    Batch-delete rows of the given model whose pk_phhd is in the given list

    Args:
        model: SQLAlchemy model class (maps to a database table)
        pk_phhds: list of primary-key values to delete

    Returns:
        the number of deleted records
    """
    delete_count = (
        session.query(model)
        .filter(model.pk_phhd.in_(pk_phhds))
        .delete(synchronize_session=False)
    )
    session.commit()
    logging.getLogger("sql").info(f"{model.__tablename__}成功删除{delete_count}条数据")
    return delete_count


if __name__ == '__main__':
    logging.config.dictConfig(LOGGING_CONFIG)
    deadline = datetime.now() - timedelta(days=EXPIRATION_DAYS)
    double_deadline = deadline - timedelta(days=EXPIRATION_DAYS)
    session = MysqlSession()
    try:
        while 1:
            # Cases that have gone through the whole workflow are deleted once they pass the expiration window
            phhds = (session.query(ZxPhhd.pk_phhd)
                     .filter(ZxPhhd.paint_flag == "9")
                     .filter(ZxPhhd.billdate < deadline)
                     .limit(BATCH_SIZE)
                     .all())
            if not phhds or len(phhds) <= 0:
                # Cases that did not pass review (and may be re-shot and re-uploaded) are deleted after twice the expiration window
                phhds = (session.query(ZxPhhd.pk_phhd)
                         .filter(ZxPhhd.exsuccess_flag == "9")
                         .filter(ZxPhhd.paint_flag == "0")
                         .filter(ZxPhhd.billdate < double_deadline)
                         .limit(BATCH_SIZE)
                         .all())
                if not phhds or len(phhds) <= 0:
                    # No more matching data; exit the loop
                    break
            pk_phhd_values = [phhd.pk_phhd for phhd in phhds]
            logging.getLogger("sql").info(f"过期的pk_phhd有{','.join(map(str, pk_phhd_values))}")
            batch_delete_by_pk_phhd(ZxPhrec, pk_phhd_values)
            batch_delete_by_pk_phhd(ZxIeResult, pk_phhd_values)
            batch_delete_by_pk_phhd(ZxIeSettlement, pk_phhd_values)
            batch_delete_by_pk_phhd(ZxIeDischarge, pk_phhd_values)
            batch_delete_by_pk_phhd(ZxIeCost, pk_phhd_values)
            batch_delete_by_pk_phhd(ZxPhhd, pk_phhd_values)
    except Exception as e:
        session.rollback()
        logging.getLogger('error').error('过期数据删除失败!', exc_info=e)
    finally:
        session.close()

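Since the script deletes rows in place, it can be useful to count what would go first. A dry-run sketch in the same style (it reuses the models and session factory above; only the finished-case filter of the first query is mirrored):

```python
# Dry-run sketch: count expired, finished cases instead of deleting them (illustrative only).
from datetime import datetime, timedelta

from db import MysqlSession
from db.mysql import ZxPhhd

EXPIRATION_DAYS = 183

if __name__ == "__main__":
    deadline = datetime.now() - timedelta(days=EXPIRATION_DAYS)
    session = MysqlSession()
    try:
        expired = (session.query(ZxPhhd.pk_phhd)
                   .filter(ZxPhhd.paint_flag == "9")
                   .filter(ZxPhhd.billdate < deadline)
                   .count())
        print(f"{expired} finished cases older than {EXPIRATION_DAYS} days would be deleted")
    finally:
        session.close()
```
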
167  doc_dewarp/.gitignore  vendored
@@ -1,167 +0,0 @@
|
||||
input/
|
||||
output/
|
||||
|
||||
runs/
|
||||
*_dump/
|
||||
*_log/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
@@ -1,34 +0,0 @@
|
||||
repos:
|
||||
# Common hooks
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
- id: check-merge-conflict
|
||||
- id: check-symlinks
|
||||
- id: detect-private-key
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/Lucas-C/pre-commit-hooks.git
|
||||
rev: v1.5.1
|
||||
hooks:
|
||||
- id: remove-crlf
|
||||
- id: remove-tabs
|
||||
name: Tabs remover (Python)
|
||||
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
|
||||
args: [--whitespaces-count, '4']
|
||||
# For Python files
|
||||
- repo: https://github.com/psf/black.git
|
||||
rev: 23.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.11.5
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.0.272
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix, --no-cache]
|
||||
@@ -1,398 +0,0 @@
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from .extractor import BasicEncoder
|
||||
from .position_encoding import build_position_encoding
|
||||
from .weight_init import weight_init_
|
||||
|
||||
|
||||
class attnLayer(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
nhead=8,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.self_attn = nn.MultiHeadAttention(d_model, nhead, dropout=dropout)
|
||||
self.multihead_attn_list = nn.LayerList(
|
||||
[
|
||||
copy.deepcopy(nn.MultiHeadAttention(d_model, nhead, dropout=dropout))
|
||||
for i in range(2)
|
||||
]
|
||||
)
|
||||
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2_list = nn.LayerList(
|
||||
[copy.deepcopy(nn.LayerNorm(d_model)) for i in range(2)]
|
||||
)
|
||||
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(p=dropout)
|
||||
self.dropout2_list = nn.LayerList(
|
||||
[copy.deepcopy(nn.Dropout(p=dropout)) for i in range(2)]
|
||||
)
|
||||
self.dropout3 = nn.Dropout(p=dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[paddle.Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
tgt,
|
||||
memory_list,
|
||||
tgt_mask=None,
|
||||
memory_mask=None,
|
||||
pos=None,
|
||||
memory_pos=None,
|
||||
):
|
||||
q = k = self.with_pos_embed(tgt, pos)
|
||||
tgt2 = self.self_attn(
|
||||
q.transpose((1, 0, 2)),
|
||||
k.transpose((1, 0, 2)),
|
||||
value=tgt.transpose((1, 0, 2)),
|
||||
attn_mask=tgt_mask,
|
||||
)
|
||||
tgt2 = tgt2.transpose((1, 0, 2))
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
|
||||
for memory, multihead_attn, norm2, dropout2, m_pos in zip(
|
||||
memory_list,
|
||||
self.multihead_attn_list,
|
||||
self.norm2_list,
|
||||
self.dropout2_list,
|
||||
memory_pos,
|
||||
):
|
||||
tgt2 = multihead_attn(
|
||||
query=self.with_pos_embed(tgt, pos).transpose((1, 0, 2)),
|
||||
key=self.with_pos_embed(memory, m_pos).transpose((1, 0, 2)),
|
||||
value=memory.transpose((1, 0, 2)),
|
||||
attn_mask=memory_mask,
|
||||
).transpose((1, 0, 2))
|
||||
|
||||
tgt = tgt + dropout2(tgt2)
|
||||
tgt = norm2(tgt)
|
||||
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
|
||||
return tgt
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask=None,
|
||||
memory_mask=None,
|
||||
pos=None,
|
||||
memory_pos=None,
|
||||
):
|
||||
tgt2 = self.norm1(tgt)
|
||||
|
||||
q = k = self.with_pos_embed(tgt2, pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask)
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2 = self.norm2(tgt)
|
||||
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt2, pos),
|
||||
key=self.with_pos_embed(memory, memory_pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
)
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
|
||||
return tgt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory_list,
|
||||
tgt_mask=None,
|
||||
memory_mask=None,
|
||||
pos=None,
|
||||
memory_pos=None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(
|
||||
tgt,
|
||||
memory_list,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
pos,
|
||||
memory_pos,
|
||||
)
|
||||
return self.forward_post(
|
||||
tgt,
|
||||
memory_list,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
pos,
|
||||
memory_pos,
|
||||
)
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
return nn.LayerList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
||||
|
||||
|
||||
class TransDecoder(nn.Layer):
|
||||
def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
|
||||
super(TransDecoder, self).__init__()
|
||||
|
||||
attn_layer = attnLayer(hidden_dim)
|
||||
self.layers = _get_clones(attn_layer, num_attn_layers)
|
||||
self.position_embedding = build_position_encoding(hidden_dim)
|
||||
|
||||
def forward(self, image: paddle.Tensor, query_embed: paddle.Tensor):
|
||||
pos = self.position_embedding(
|
||||
paddle.ones([image.shape[0], image.shape[2], image.shape[3]], dtype="bool")
|
||||
)
|
||||
|
||||
b, c, h, w = image.shape
|
||||
|
||||
image = image.flatten(2).transpose(perm=[2, 0, 1])
|
||||
pos = pos.flatten(2).transpose(perm=[2, 0, 1])
|
||||
|
||||
for layer in self.layers:
|
||||
query_embed = layer(query_embed, [image], pos=pos, memory_pos=[pos, pos])
|
||||
|
||||
query_embed = query_embed.transpose(perm=[1, 2, 0]).reshape([b, c, h, w])
|
||||
|
||||
return query_embed
|
||||
|
||||
|
||||
class TransEncoder(nn.Layer):
|
||||
def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
|
||||
super(TransEncoder, self).__init__()
|
||||
|
||||
attn_layer = attnLayer(hidden_dim)
|
||||
self.layers = _get_clones(attn_layer, num_attn_layers)
|
||||
self.position_embedding = build_position_encoding(hidden_dim)
|
||||
|
||||
def forward(self, image: paddle.Tensor):
|
||||
pos = self.position_embedding(
|
||||
paddle.ones([image.shape[0], image.shape[2], image.shape[3]], dtype="bool")
|
||||
)
|
||||
|
||||
b, c, h, w = image.shape
|
||||
|
||||
image = image.flatten(2).transpose(perm=[2, 0, 1])
|
||||
pos = pos.flatten(2).transpose(perm=[2, 0, 1])
|
||||
|
||||
for layer in self.layers:
|
||||
image = layer(image, [image], pos=pos, memory_pos=[pos, pos])
|
||||
|
||||
image = image.transpose(perm=[1, 2, 0]).reshape([b, c, h, w])
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class FlowHead(nn.Layer):
|
||||
def __init__(self, input_dim=128, hidden_dim=256):
|
||||
super(FlowHead, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2D(input_dim, hidden_dim, 3, padding=1)
|
||||
self.conv2 = nn.Conv2D(hidden_dim, 2, 3, padding=1)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv2(self.relu(self.conv1(x)))
|
||||
|
||||
|
||||
class UpdateBlock(nn.Layer):
|
||||
def __init__(self, hidden_dim: int = 128):
|
||||
super(UpdateBlock, self).__init__()
|
||||
|
||||
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
|
||||
|
||||
self.mask = nn.Sequential(
|
||||
nn.Conv2D(hidden_dim, 256, 3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2D(256, 64 * 9, 1, padding=0),
|
||||
)
|
||||
|
||||
def forward(self, image, coords):
|
||||
mask = 0.25 * self.mask(image)
|
||||
dflow = self.flow_head(image)
|
||||
coords = coords + dflow
|
||||
return mask, coords
|
||||
|
||||
|
||||
def coords_grid(batch, ht, wd):
|
||||
coords = paddle.meshgrid(paddle.arange(end=ht), paddle.arange(end=wd))
|
||||
coords = paddle.stack(coords[::-1], axis=0).astype(dtype="float32")
|
||||
return coords[None].tile([batch, 1, 1, 1])
|
||||
|
||||
|
||||
def upflow8(flow, mode="bilinear"):
|
||||
new_size = 8 * flow.shape[2], 8 * flow.shape[3]
|
||||
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
|
||||
|
||||
|
||||
class OverlapPatchEmbed(nn.Layer):
|
||||
"""Image to Patch Embedding"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
|
||||
super().__init__()
|
||||
|
||||
img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
|
||||
patch_size = (
|
||||
patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
|
||||
)
|
||||
|
||||
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
||||
self.num_patches = self.H * self.W
|
||||
|
||||
self.proj = nn.Conv2D(
|
||||
in_chans,
|
||||
embed_dim,
|
||||
patch_size,
|
||||
stride=stride,
|
||||
padding=(patch_size[0] // 2, patch_size[1] // 2),
|
||||
)
|
||||
self.norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
weight_init_(m, "trunc_normal_", std=0.02)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
weight_init_(m, "Constant", value=1.0)
|
||||
elif isinstance(m, nn.Conv2D):
|
||||
weight_init_(
|
||||
m.weight, "kaiming_normal_", mode="fan_out", nonlinearity="relu"
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
_, _, H, W = x.shape
|
||||
x = x.flatten(2)
|
||||
|
||||
perm = list(range(x.ndim))
|
||||
perm[1] = 2
|
||||
perm[2] = 1
|
||||
x = x.transpose(perm=perm)
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
return x, H, W
|
||||
|
||||
|
||||
class GeoTr(nn.Layer):
|
||||
def __init__(self):
|
||||
super(GeoTr, self).__init__()
|
||||
|
||||
self.hidden_dim = hdim = 256
|
||||
|
||||
self.fnet = BasicEncoder(output_dim=hdim, norm_fn="instance")
|
||||
|
||||
self.encoder_block = [("encoder_block" + str(i)) for i in range(3)]
|
||||
for i in self.encoder_block:
|
||||
self.__setattr__(i, TransEncoder(2, hidden_dim=hdim))
|
||||
|
||||
self.down_layer = [("down_layer" + str(i)) for i in range(2)]
|
||||
for i in self.down_layer:
|
||||
self.__setattr__(i, nn.Conv2D(256, 256, 3, stride=2, padding=1))
|
||||
|
||||
self.decoder_block = [("decoder_block" + str(i)) for i in range(3)]
|
||||
for i in self.decoder_block:
|
||||
self.__setattr__(i, TransDecoder(2, hidden_dim=hdim))
|
||||
|
||||
self.up_layer = [("up_layer" + str(i)) for i in range(2)]
|
||||
for i in self.up_layer:
|
||||
self.__setattr__(
|
||||
i, nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
|
||||
)
|
||||
|
||||
self.query_embed = nn.Embedding(81, self.hidden_dim)
|
||||
|
||||
self.update_block = UpdateBlock(self.hidden_dim)
|
||||
|
||||
def initialize_flow(self, img):
|
||||
N, _, H, W = img.shape
|
||||
coodslar = coords_grid(N, H, W)
|
||||
coords0 = coords_grid(N, H // 8, W // 8)
|
||||
coords1 = coords_grid(N, H // 8, W // 8)
|
||||
return coodslar, coords0, coords1
|
||||
|
||||
def upsample_flow(self, flow, mask):
|
||||
N, _, H, W = flow.shape
|
||||
|
||||
mask = mask.reshape([N, 1, 9, 8, 8, H, W])
|
||||
mask = F.softmax(mask, axis=2)
|
||||
|
||||
up_flow = F.unfold(8 * flow, [3, 3], paddings=1)
|
||||
up_flow = up_flow.reshape([N, 2, 9, 1, 1, H, W])
|
||||
|
||||
up_flow = paddle.sum(mask * up_flow, axis=2)
|
||||
up_flow = up_flow.transpose(perm=[0, 1, 4, 2, 5, 3])
|
||||
|
||||
return up_flow.reshape([N, 2, 8 * H, 8 * W])
|
||||
|
||||
def forward(self, image):
|
||||
fmap = self.fnet(image)
|
||||
fmap = F.relu(fmap)
|
||||
|
||||
fmap1 = self.__getattr__(self.encoder_block[0])(fmap)
|
||||
fmap1d = self.__getattr__(self.down_layer[0])(fmap1)
|
||||
fmap2 = self.__getattr__(self.encoder_block[1])(fmap1d)
|
||||
fmap2d = self.__getattr__(self.down_layer[1])(fmap2)
|
||||
fmap3 = self.__getattr__(self.encoder_block[2])(fmap2d)
|
||||
|
||||
query_embed0 = self.query_embed.weight.unsqueeze(1).tile([1, fmap3.shape[0], 1])
|
||||
|
||||
fmap3d_ = self.__getattr__(self.decoder_block[0])(fmap3, query_embed0)
|
||||
fmap3du_ = (
|
||||
self.__getattr__(self.up_layer[0])(fmap3d_)
|
||||
.flatten(2)
|
||||
.transpose(perm=[2, 0, 1])
|
||||
)
|
||||
fmap2d_ = self.__getattr__(self.decoder_block[1])(fmap2, fmap3du_)
|
||||
fmap2du_ = (
|
||||
self.__getattr__(self.up_layer[1])(fmap2d_)
|
||||
.flatten(2)
|
||||
.transpose(perm=[2, 0, 1])
|
||||
)
|
||||
fmap_out = self.__getattr__(self.decoder_block[2])(fmap1, fmap2du_)
|
||||
|
||||
coodslar, coords0, coords1 = self.initialize_flow(image)
|
||||
coords1 = coords1.detach()
|
||||
mask, coords1 = self.update_block(fmap_out, coords1)
|
||||
flow_up = self.upsample_flow(coords1 - coords0, mask)
|
||||
bm_up = coodslar + flow_up
|
||||
|
||||
return bm_up
|
||||
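The removed GeoTr module predicts a backward map that is later used to unwarp a document photo. For reference, a dummy forward pass could be run like this (a sketch; the import path is assumed, and the 288x288 size follows the training scripts):

```python
# Dummy forward pass through the removed GeoTr model (illustrative; import path assumed).
import paddle

from doc_dewarp.GeoTr import GeoTr  # assumed location of the deleted module

model = GeoTr()
model.eval()

image = paddle.rand([1, 3, 288, 288])  # N, C, H, W; height/width must be divisible by 8
with paddle.no_grad():
    bm = model(image)                  # predicted backward map
print(bm.shape)                        # expected: [1, 2, 288, 288]
```
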
@@ -1,73 +0,0 @@
|
||||
# DocTrPP: DocTr++ in PaddlePaddle
|
||||
|
||||
## Introduction
|
||||
|
||||
This is a PaddlePaddle implementation of DocTr++. The original paper is [DocTr++: Deep Unrestricted Document Image Rectification](https://arxiv.org/abs/2304.08796). The original code is [here](https://github.com/fh2019ustc/DocTr-Plus).
|
||||
|
||||

|
||||
|
||||
## Requirements
|
||||
|
||||
You need to install the latest version of PaddlePaddle, which is done through this [link](https://www.paddlepaddle.org.cn/).
|
||||
|
||||
## Training
|
||||
|
||||
1. Data Preparation
|
||||
|
||||
To prepare datasets, refer to [doc3D](https://github.com/cvlab-stonybrook/doc3D-dataset).
|
||||
|
||||
2. Training
|
||||
|
||||
```shell
|
||||
sh train.sh
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```shell
|
||||
export OPENCV_IO_ENABLE_OPENEXR=1
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
python train.py --img-size 288 \
|
||||
--name "DocTr++" \
|
||||
--batch-size 12 \
|
||||
--lr 2.5e-5 \
|
||||
--exist-ok \
|
||||
--use-vdl
|
||||
```
|
||||
|
||||
3. Load Trained Model and Continue Training
|
||||
|
||||
```shell
|
||||
export OPENCV_IO_ENABLE_OPENEXR=1
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
python train.py --img-size 288 \
|
||||
--name "DocTr++" \
|
||||
--batch-size 12 \
|
||||
--lr 2.5e-5 \
|
||||
--resume "runs/train/DocTr++/weights/last.ckpt" \
|
||||
--exist-ok \
|
||||
--use-vdl
|
||||
```
|
||||
|
||||
## Test and Inference
|
||||
|
||||
Test the dewarp result on a single image:
|
||||
|
||||
```shell
|
||||
python predict.py -i "crop/12_2 copy.png" -m runs/train/DocTr++/weights/best.ckpt -o 12.2.png
|
||||
```
|
||||

|
||||
|
||||
## Export to onnx
|
||||
|
||||
```
|
||||
pip install paddle2onnx
|
||||
|
||||
python export.py -m ./best.ckpt --format onnx
|
||||
```
|
||||
|
||||
## Model Download
|
||||
|
||||
The trained model can be downloaded from [here](https://github.com/GreatV/DocTrPP/releases/download/v0.0.2/best.ckpt).
|
||||
@@ -1,4 +0,0 @@
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
DOC_TR = InferenceSession("model/dewarp_model/doc_tr_pp.onnx",
|
||||
providers=["CUDAExecutionProvider"], provider_options=[{"device_id": 0}])
|
||||
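The removed snippet pins the session to CUDAExecutionProvider only, so it fails outright when no GPU is visible. If a similar session is recreated elsewhere, adding a CPU fallback is a common pattern (a sketch, not the repo's code):

```python
# Sketch: prefer CUDA but fall back to CPU instead of requiring a GPU (not the repo's code).
from onnxruntime import InferenceSession

session = InferenceSession(
    "model/dewarp_model/doc_tr_pp.onnx",  # path taken from the removed snippet
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    provider_options=[{"device_id": 0}, {}],  # one options dict per provider
)
print(session.get_providers())  # shows which providers were actually enabled
```
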
@@ -1,133 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import hdf5storage as h5
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
|
||||
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--data_root",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="~/datasets/doc3d/",
|
||||
help="Path to the downloaded dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--folder", nargs="?", type=int, default=1, help="Folder ID to read from"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="output.png",
|
||||
help="Output filename for the image",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
root = os.path.expanduser(args.data_root)
|
||||
folder = args.folder
|
||||
dirname = os.path.join(root, "img", str(folder))
|
||||
|
||||
choices = [f for f in os.listdir(dirname) if "png" in f]
|
||||
fname = random.choice(choices)
|
||||
|
||||
# Read Image
|
||||
img_path = os.path.join(dirname, fname)
|
||||
img = cv2.imread(img_path)
|
||||
|
||||
# Read 3D Coords
|
||||
wc_path = os.path.join(root, "wc", str(folder), fname[:-3] + "exr")
|
||||
wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
|
||||
# scale wc
|
||||
# value obtained from the entire dataset
|
||||
xmx, xmn, ymx, ymn, zmx, zmn = (
|
||||
1.2539363,
|
||||
-1.2442188,
|
||||
1.2396319,
|
||||
-1.2289206,
|
||||
0.6436657,
|
||||
-0.67492497,
|
||||
)
|
||||
wc[:, :, 0] = (wc[:, :, 0] - zmn) / (zmx - zmn)
|
||||
wc[:, :, 1] = (wc[:, :, 1] - ymn) / (ymx - ymn)
|
||||
wc[:, :, 2] = (wc[:, :, 2] - xmn) / (xmx - xmn)
|
||||
|
||||
# Read Backward Map
|
||||
bm_path = os.path.join(root, "bm", str(folder), fname[:-3] + "mat")
|
||||
bm = h5.loadmat(bm_path)["bm"]
|
||||
|
||||
# Read UV Map
|
||||
uv_path = os.path.join(root, "uv", str(folder), fname[:-3] + "exr")
|
||||
uv = cv2.imread(uv_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
|
||||
|
||||
# Read Depth Map
|
||||
dmap_path = os.path.join(root, "dmap", str(folder), fname[:-3] + "exr")
|
||||
dmap = cv2.imread(dmap_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)[:, :, 0]
|
||||
# do some clipping and scaling to display it
|
||||
dmap[dmap > 30.0] = 30
|
||||
dmap = 1 - ((dmap - np.min(dmap)) / (np.max(dmap) - np.min(dmap)))
|
||||
|
||||
# Read Normal Map
|
||||
norm_path = os.path.join(root, "norm", str(folder), fname[:-3] + "exr")
|
||||
norm = cv2.imread(norm_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
|
||||
|
||||
# Read Albedo
|
||||
alb_path = os.path.join(root, "alb", str(folder), fname[:-3] + "png")
|
||||
alb = cv2.imread(alb_path)
|
||||
|
||||
# Read Checkerboard Image
|
||||
recon_path = os.path.join(root, "recon", str(folder), fname[:-8] + "chess480001.png")
|
||||
recon = cv2.imread(recon_path)
|
||||
|
||||
# Display image and GTs
|
||||
|
||||
# use the backward mapping to dewarp the image
|
||||
# scale bm to -1.0 to 1.0
|
||||
bm_ = bm / np.array([448, 448])
|
||||
bm_ = (bm_ - 0.5) * 2
|
||||
bm_ = np.reshape(bm_, (1, 448, 448, 2))
|
||||
bm_ = paddle.to_tensor(bm_, dtype="float32")
|
||||
img_ = alb.transpose((2, 0, 1)).astype(np.float32) / 255.0
|
||||
img_ = np.expand_dims(img_, 0)
|
||||
img_ = paddle.to_tensor(img_, dtype="float32")
|
||||
uw = F.grid_sample(img_, bm_)
|
||||
uw = uw[0].numpy().transpose((1, 2, 0))
|
||||
|
||||
f, axrr = plt.subplots(2, 5)
|
||||
for ax in axrr:
|
||||
for a in ax:
|
||||
a.set_xticks([])
|
||||
a.set_yticks([])
|
||||
|
||||
axrr[0][0].imshow(img)
|
||||
axrr[0][0].title.set_text("image")
|
||||
axrr[0][1].imshow(wc)
|
||||
axrr[0][1].title.set_text("3D coords")
|
||||
axrr[0][2].imshow(bm[:, :, 0])
|
||||
axrr[0][2].title.set_text("bm 0")
|
||||
axrr[0][3].imshow(bm[:, :, 1])
|
||||
axrr[0][3].title.set_text("bm 1")
|
||||
if uv is None:
|
||||
uv = np.zeros_like(img)
|
||||
axrr[0][4].imshow(uv)
|
||||
axrr[0][4].title.set_text("uv map")
|
||||
axrr[1][0].imshow(dmap)
|
||||
axrr[1][0].title.set_text("depth map")
|
||||
axrr[1][1].imshow(norm)
|
||||
axrr[1][1].title.set_text("normal map")
|
||||
axrr[1][2].imshow(alb)
|
||||
axrr[1][2].title.set_text("albedo")
|
||||
axrr[1][3].imshow(recon)
|
||||
axrr[1][3].title.set_text("checkerboard")
|
||||
axrr[1][4].imshow(uw)
|
||||
axrr[1][4].title.set_text("gt unwarped")
|
||||
plt.tight_layout()
|
||||
plt.savefig(args.output)
|
||||
@@ -1,21 +0,0 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from . import DOC_TR
|
||||
from .utils import to_tensor, to_image
|
||||
|
||||
|
||||
def dewarp_image(image):
|
||||
img = cv2.resize(image, (288, 288)).astype(np.float32)
|
||||
y = to_tensor(image)
|
||||
|
||||
img = np.transpose(img, (2, 0, 1))
|
||||
bm = DOC_TR.run(None, {"image": img[None,]})[0]
|
||||
bm = paddle.to_tensor(bm)
|
||||
bm = paddle.nn.functional.interpolate(
|
||||
bm, y.shape[2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
bm_nhwc = np.transpose(bm, (0, 2, 3, 1))
|
||||
out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
|
||||
return to_image(out)
|
||||
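The removed dewarp_image helper takes a BGR NumPy image and returns the rectified image at its original resolution. A hypothetical call site (the file paths and import location are assumptions):

```python
# Hypothetical call site for the removed helper (paths and package name are assumptions).
import cv2

from dewarp import dewarp_image  # assumed package name for the removed module

image = cv2.imread("input/sample_page.jpg")  # BGR image as loaded by OpenCV
flat = dewarp_image(image)                   # rectified image at the original resolution
cv2.imwrite("output/sample_page_flat.jpg", flat)
```
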
@@ -1,161 +0,0 @@
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_1.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_2.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_3.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_4.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_5.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_6.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_7.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_8.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_9.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_10.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_11.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_12.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_13.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_14.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_15.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_16.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_17.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_18.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_19.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_20.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_21.zip"
|
||||
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_1.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_2.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_3.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_4.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_5.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_6.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_7.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_8.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_9.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_10.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_11.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_12.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_13.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_14.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_15.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_16.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_17.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_18.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_19.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_20.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_21.zip"
|
||||
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_1.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_2.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_3.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_4.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_5.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_6.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_7.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_8.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_9.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_10.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_11.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_12.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_13.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_14.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_15.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_16.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_17.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_18.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_19.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_20.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_21.zip"
|
||||
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_1.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_2.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_3.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_4.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_5.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_6.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_7.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_8.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_10.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_11.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_12.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_13.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_14.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_15.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_16.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_17.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_18.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_19.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_20.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_21.zip"
|
||||
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_1.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_2.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_3.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_4.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_5.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_6.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_7.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_8.zip"
|
||||
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_1.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_2.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_3.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_4.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_5.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_6.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_7.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_8.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_9.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_10.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_11.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_12.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_13.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_14.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_15.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_16.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_17.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_18.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_19.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_20.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_21.zip"
|
||||
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_1.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_2.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_3.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_4.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_5.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_6.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_7.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_8.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_9.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_10.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_11.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_12.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_13.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_14.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_15.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_16.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_17.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_18.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_19.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_20.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_21.zip"
|
||||
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_1.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_2.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_3.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_4.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_5.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_6.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_7.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_8.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_9.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_10.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_11.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_12.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_13.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_14.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_15.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_16.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_17.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_18.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_19.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_20.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_21.zip"
|
||||
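The removed download script repeats one wget call per archive. The same downloads could be driven by a short loop; a sketch in Python (the index lists are illustrative and should be filled in from the list above, which skips a few numbers):

```python
# Sketch of the removed download script as a loop (the index lists are illustrative;
# copy the exact archive names from the wget list above).
import urllib.request

BASE = "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d"
ARCHIVES = {
    "img": range(1, 22),
    "wc": range(1, 22),
    "bm": range(1, 22),
    # ... remaining prefixes (uv, alb, recon, norm, dmap) with their own index lists
}

for prefix, indices in ARCHIVES.items():
    for i in indices:
        name = f"{prefix}_{i}.zip"
        urllib.request.urlretrieve(f"{BASE}/{name}", name)  # same archives wget fetched
        print("downloaded", name)
```
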
Binary file not shown (removed image, 76 KiB).
@@ -1,129 +0,0 @@
|
||||
import collections
|
||||
import os
|
||||
import random
|
||||
|
||||
import albumentations as A
|
||||
import cv2
|
||||
import hdf5storage as h5
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import io
|
||||
|
||||
# Set random seed
|
||||
random.seed(12345678)
|
||||
|
||||
|
||||
class Doc3dDataset(io.Dataset):
|
||||
def __init__(self, root, split="train", is_augment=False, image_size=512):
|
||||
self.root = os.path.expanduser(root)
|
||||
|
||||
self.split = split
|
||||
self.is_augment = is_augment
|
||||
|
||||
self.files = collections.defaultdict(list)
|
||||
|
||||
self.image_size = (
|
||||
image_size if isinstance(image_size, tuple) else (image_size, image_size)
|
||||
)
|
||||
|
||||
# Augmentation
|
||||
self.augmentation = A.Compose(
|
||||
[
|
||||
A.ColorJitter(),
|
||||
]
|
||||
)
|
||||
|
||||
for split in ["train", "val"]:
|
||||
path = os.path.join(self.root, split + ".txt")
|
||||
file_list = []
|
||||
with open(path, "r") as file:
|
||||
file_list = [file_id.rstrip() for file_id in file.readlines()]
|
||||
self.files[split] = file_list
|
||||
|
||||
def __len__(self):
|
||||
return len(self.files[self.split])
|
||||
|
||||
def __getitem__(self, index):
|
||||
image_name = self.files[self.split][index]
|
||||
|
||||
# Read image
|
||||
image_path = os.path.join(self.root, "img", image_name + ".png")
|
||||
image = cv2.imread(image_path)
|
||||
|
||||
# Read 3D Coordinates
|
||||
wc_path = os.path.join(self.root, "wc", image_name + ".exr")
|
||||
wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
|
||||
|
||||
# Read backward map
|
||||
bm_path = os.path.join(self.root, "bm", image_name + ".mat")
|
||||
bm = h5.loadmat(bm_path)["bm"]
|
||||
|
||||
image, bm = self.transform(wc, bm, image)
|
||||
|
||||
return image, bm
|
||||
|
||||
def tight_crop(self, wc: np.ndarray):
|
||||
mask = ((wc[:, :, 0] != 0) & (wc[:, :, 1] != 0) & (wc[:, :, 2] != 0)).astype(
|
||||
np.uint8
|
||||
)
|
||||
mask_size = mask.shape
|
||||
[y, x] = mask.nonzero()
|
||||
min_x = min(x)
|
||||
max_x = max(x)
|
||||
min_y = min(y)
|
||||
max_y = max(y)
|
||||
|
||||
wc = wc[min_y : max_y + 1, min_x : max_x + 1, :]
|
||||
s = 10
|
||||
wc = np.pad(wc, ((s, s), (s, s), (0, 0)), "constant")
|
||||
|
||||
cx1 = random.randint(0, 2 * s)
|
||||
cx2 = random.randint(0, 2 * s) + 1
|
||||
cy1 = random.randint(0, 2 * s)
|
||||
cy2 = random.randint(0, 2 * s) + 1
|
||||
|
||||
wc = wc[cy1:-cy2, cx1:-cx2, :]
|
||||
|
||||
top: int = min_y - s + cy1
|
||||
bottom: int = mask_size[0] - max_y - s + cy2
|
||||
left: int = min_x - s + cx1
|
||||
right: int = mask_size[1] - max_x - s + cx2
|
||||
|
||||
top = np.clip(top, 0, mask_size[0])
|
||||
bottom = np.clip(bottom, 1, mask_size[0] - 1)
|
||||
left = np.clip(left, 0, mask_size[1])
|
||||
right = np.clip(right, 1, mask_size[1] - 1)
|
||||
|
||||
return wc, top, bottom, left, right
|
||||
|
||||
def transform(self, wc, bm, img):
|
||||
wc, top, bottom, left, right = self.tight_crop(wc)
|
||||
|
||||
img = img[top:-bottom, left:-right, :]
|
||||
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
if self.is_augment:
|
||||
img = self.augmentation(image=img)["image"]
|
||||
|
||||
# resize image
|
||||
img = cv2.resize(img, self.image_size)
|
||||
img = img.astype(np.float32) / 255.0
|
||||
img = img.transpose(2, 0, 1)
|
||||
|
||||
# resize bm
|
||||
bm = bm.astype(np.float32)
|
||||
bm[:, :, 1] = bm[:, :, 1] - top
|
||||
bm[:, :, 0] = bm[:, :, 0] - left
|
||||
bm = bm / np.array([448.0 - left - right, 448.0 - top - bottom])
|
||||
bm0 = cv2.resize(bm[:, :, 0], (self.image_size[0], self.image_size[1]))
|
||||
bm1 = cv2.resize(bm[:, :, 1], (self.image_size[0], self.image_size[1]))
|
||||
bm0 = bm0 * self.image_size[0]
|
||||
bm1 = bm1 * self.image_size[1]
|
||||
|
||||
bm = np.stack([bm0, bm1], axis=-1)
|
||||
|
||||
img = paddle.to_tensor(img).astype(dtype="float32")
|
||||
bm = paddle.to_tensor(bm).astype(dtype="float32")
|
||||
|
||||
return img, bm
|
||||
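The removed Doc3dDataset yields (image, backward map) pairs resized to image_size. A minimal loader sketch (the dataset root and module path are placeholders):

```python
# Sketch: iterate the removed Doc3dDataset with a Paddle DataLoader (paths are placeholders).
from paddle.io import DataLoader

from loaders.doc3d import Doc3dDataset  # assumed module path of the deleted dataset class

dataset = Doc3dDataset("~/datasets/doc3d/", split="train", is_augment=True, image_size=288)
loader = DataLoader(dataset, batch_size=12, shuffle=True, num_workers=4)

for images, bms in loader:
    print(images.shape, bms.shape)  # e.g. [12, 3, 288, 288] and [12, 288, 288, 2]
    break
```
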
@@ -1,66 +0,0 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import paddle
|
||||
|
||||
from GeoTr import GeoTr
|
||||
|
||||
|
||||
def export(args):
|
||||
model_path = args.model
|
||||
imgsz = args.imgsz
|
||||
format = args.format
|
||||
|
||||
model = GeoTr()
|
||||
checkpoint = paddle.load(model_path)
|
||||
model.set_state_dict(checkpoint["model"])
|
||||
model.eval()
|
||||
|
||||
dirname = os.path.dirname(model_path)
|
||||
if format == "static" or format == "onnx":
|
||||
model = paddle.jit.to_static(
|
||||
model,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec(shape=[1, 3, imgsz, imgsz], dtype="float32")
|
||||
],
|
||||
full_graph=True,
|
||||
)
|
||||
paddle.jit.save(model, os.path.join(dirname, "model"))
|
||||
|
||||
if format == "onnx":
|
||||
onnx_path = os.path.join(dirname, "model.onnx")
|
||||
os.system(
|
||||
f"paddle2onnx --model_dir {dirname}"
|
||||
" --model_filename model.pdmodel"
|
||||
" --params_filename model.pdiparams"
|
||||
f" --save_file {onnx_path}"
|
||||
" --opset_version 11"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="export model")
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"-m",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="",
|
||||
help="The path of model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--imgsz", type=int, default=288, help="The size of input image"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--format",
|
||||
type=str,
|
||||
default="static",
|
||||
help="The format of exported model, which can be static or onnx",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
export(args)
|
||||
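If the ONNX export above succeeds, the saved `model.onnx` can be sanity-checked with onnxruntime; this is only a sketch, and the output shape is shown as an expectation rather than a guarantee.

```python
# Sketch: run the exported ONNX model on a dummy 288x288 input.
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
input_name = sess.get_inputs()[0].name

x = np.random.rand(1, 3, 288, 288).astype(np.float32)
outputs = sess.run(None, {input_name: x})
print(outputs[0].shape)  # backward map predicted by GeoTr, expected (1, 2, 288, 288)
```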
@@ -1,110 +0,0 @@
|
||||
import paddle.nn as nn
|
||||
|
||||
from .weight_init import kaiming_normal_, weight_init_
|
||||
|
||||
|
||||
class ResidualBlock(nn.Layer):
|
||||
"""Residual Block with custom normalization."""
|
||||
|
||||
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2D(in_planes, planes, 3, padding=1, stride=stride)
|
||||
self.conv2 = nn.Conv2D(planes, planes, 3, padding=1)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
if norm_fn == "group":
|
||||
num_groups = planes // 8
|
||||
self.norm1 = nn.GroupNorm(num_groups, planes)
|
||||
self.norm2 = nn.GroupNorm(num_groups, planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.GroupNorm(num_groups, planes)
|
||||
elif norm_fn == "batch":
|
||||
self.norm1 = nn.BatchNorm2D(planes)
|
||||
self.norm2 = nn.BatchNorm2D(planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.BatchNorm2D(planes)
|
||||
elif norm_fn == "instance":
|
||||
self.norm1 = nn.InstanceNorm2D(planes)
|
||||
self.norm2 = nn.InstanceNorm2D(planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.InstanceNorm2D(planes)
|
||||
elif norm_fn == "none":
|
||||
self.norm1 = nn.Sequential()
|
||||
self.norm2 = nn.Sequential()
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.Sequential()
|
||||
|
||||
if stride == 1:
|
||||
self.downsample = None
|
||||
else:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2D(in_planes, planes, 1, stride=stride), self.norm3
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = x
|
||||
y = self.relu(self.norm1(self.conv1(y)))
|
||||
y = self.relu(self.norm2(self.conv2(y)))
|
||||
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
|
||||
return self.relu(x + y)
|
||||
|
||||
|
||||
class BasicEncoder(nn.Layer):
|
||||
"""Basic Encoder with custom normalization."""
|
||||
|
||||
def __init__(self, output_dim=128, norm_fn="batch"):
|
||||
super(BasicEncoder, self).__init__()
|
||||
|
||||
self.norm_fn = norm_fn
|
||||
|
||||
if self.norm_fn == "group":
|
||||
self.norm1 = nn.GroupNorm(8, 64)
|
||||
elif self.norm_fn == "batch":
|
||||
self.norm1 = nn.BatchNorm2D(64)
|
||||
elif self.norm_fn == "instance":
|
||||
self.norm1 = nn.InstanceNorm2D(64)
|
||||
elif self.norm_fn == "none":
|
||||
self.norm1 = nn.Sequential()
|
||||
|
||||
self.conv1 = nn.Conv2D(3, 64, 7, stride=2, padding=3)
|
||||
self.relu1 = nn.ReLU()
|
||||
|
||||
self.in_planes = 64
|
||||
self.layer1 = self._make_layer(64, stride=1)
|
||||
self.layer2 = self._make_layer(128, stride=2)
|
||||
self.layer3 = self._make_layer(192, stride=2)
|
||||
|
||||
self.conv2 = nn.Conv2D(192, output_dim, 1)
|
||||
|
||||
for m in self.sublayers():
|
||||
if isinstance(m, nn.Conv2D):
|
||||
kaiming_normal_(
|
||||
m.weight, mode="fan_out", nonlinearity="relu"
|
||||
)
|
||||
elif isinstance(m, (nn.BatchNorm2D, nn.InstanceNorm2D, nn.GroupNorm)):
|
||||
weight_init_(m, "Constant", value=1, bias_value=0.0)
|
||||
|
||||
def _make_layer(self, dim, stride=1):
|
||||
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
||||
layers = layer1, layer2
|
||||
|
||||
self.in_planes = dim
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.relu1(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
|
||||
x = self.conv2(x)
|
||||
|
||||
return x
|
||||
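A quick shape check for the encoder above (a sketch; the module name is assumed). The spatial size drops by a factor of 8 because of the three stride-2 stages: the 7x7 stem, `layer2`, and `layer3`.

```python
import paddle

from extractor import BasicEncoder  # assumed module name for the file above

encoder = BasicEncoder(output_dim=128)  # default batch normalization
x = paddle.randn([1, 3, 288, 288])
feat = encoder(x)
print(feat.shape)  # [1, 128, 36, 36]: 288 / 2 / 2 / 2 = 36
```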
@@ -1,48 +0,0 @@
|
||||
import copy
|
||||
import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import paddle.optimizer as optim
|
||||
|
||||
from GeoTr import GeoTr
|
||||
|
||||
|
||||
def plot_lr_scheduler(optimizer, scheduler, epochs=65, save_dir=""):
|
||||
"""
|
||||
Plot the learning rate scheduler
|
||||
"""
|
||||
|
||||
optimizer = copy.copy(optimizer)
|
||||
scheduler = copy.copy(scheduler)
|
||||
|
||||
lr = []
|
||||
for _ in range(epochs):
|
||||
for _ in range(30):
|
||||
lr.append(scheduler.get_lr())
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
epoch = [float(i) / 30.0 for i in range(len(lr))]
|
||||
|
||||
plt.figure()
|
||||
plt.plot(epoch, lr, ".-", label="Learning Rate")
|
||||
plt.xlabel("epoch")
|
||||
plt.ylabel("Learning Rate")
|
||||
plt.title("Learning Rate Scheduler")
|
||||
plt.savefig(os.path.join(save_dir, "lr_scheduler.png"), dpi=300)
|
||||
plt.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = GeoTr()
|
||||
|
||||
scheduler = optim.lr.OneCycleLR(
|
||||
max_learning_rate=1e-4,
|
||||
total_steps=1950,
|
||||
phase_pct=0.1,
|
||||
end_learning_rate=1e-4 / 2.5e5,
|
||||
)
|
||||
optimizer = optim.AdamW(learning_rate=scheduler, parameters=model.parameters())
|
||||
plot_lr_scheduler(
|
||||
scheduler=scheduler, optimizer=optimizer, epochs=65, save_dir="./"
|
||||
)
|
||||
@@ -1,124 +0,0 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.initializer as init
|
||||
|
||||
|
||||
class NestedTensor(object):
|
||||
def __init__(self, tensors, mask: Optional[paddle.Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Layer):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
|
||||
):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, mask):
|
||||
assert mask is not None
|
||||
|
||||
y_embed = mask.cumsum(axis=1, dtype="float32")
|
||||
x_embed = mask.cumsum(axis=2, dtype="float32")
|
||||
|
||||
if self.normalize:
|
||||
eps = 1e-06
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = paddle.arange(end=self.num_pos_feats, dtype="float32")
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
|
||||
pos_x = paddle.stack(
|
||||
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), axis=4
|
||||
).flatten(3)
|
||||
pos_y = paddle.stack(
|
||||
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), axis=4
|
||||
).flatten(3)
|
||||
|
||||
pos = paddle.concat((pos_y, pos_x), axis=3).transpose(perm=[0, 3, 1, 2])
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
class PositionEmbeddingLearned(nn.Layer):
|
||||
"""
|
||||
Absolute pos embedding, learned.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=256):
|
||||
super().__init__()
|
||||
self.row_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.col_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
init_uniform = init.Uniform()
|
||||
init_uniform(self.row_embed.weight)
|
||||
init_uniform(self.col_embed.weight)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
x = tensor_list.tensors
|
||||
|
||||
h, w = x.shape[-2:]
|
||||
|
||||
i = paddle.arange(end=w)
|
||||
j = paddle.arange(end=h)
|
||||
|
||||
x_emb = self.col_embed(i)
|
||||
y_emb = self.row_embed(j)
|
||||
|
||||
pos = (
|
||||
paddle.concat(
|
||||
[
|
||||
x_emb.unsqueeze(0).tile([h, 1, 1]),
|
||||
y_emb.unsqueeze(1).tile([1, w, 1]),
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
.transpose([2, 0, 1])
|
||||
.unsqueeze(0)
|
||||
.tile([x.shape[0], 1, 1, 1])
|
||||
)
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
def build_position_encoding(hidden_dim=512, position_embedding="sine"):
|
||||
N_steps = hidden_dim // 2
|
||||
|
||||
if position_embedding in ("v2", "sine"):
|
||||
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
|
||||
elif position_embedding in ("v3", "learned"):
|
||||
position_embedding = PositionEmbeddingLearned(N_steps)
|
||||
else:
|
||||
raise ValueError(f"not supported {position_embedding}")
|
||||
return position_embedding
|
||||
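A small sketch of the sine variant above: with `hidden_dim=256` each axis gets 128 features, and the result has shape `[N, hidden_dim, H, W]` for a mask of shape `[N, H, W]`.

```python
import paddle

from position_encoding import build_position_encoding  # assumed module name

pos_enc = build_position_encoding(hidden_dim=256, position_embedding="sine")
mask = paddle.ones([2, 36, 36])  # all-ones mask over a 36x36 feature map
pos = pos_enc(mask)
print(pos.shape)  # [2, 256, 36, 36]
```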
@@ -1,69 +0,0 @@
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
import paddle
|
||||
|
||||
from GeoTr import GeoTr
|
||||
from utils import to_image, to_tensor
|
||||
|
||||
|
||||
def run(args):
|
||||
image_path = args.image
|
||||
model_path = args.model
|
||||
output_path = args.output
|
||||
|
||||
checkpoint = paddle.load(model_path)
|
||||
state_dict = checkpoint["model"]
|
||||
model = GeoTr()
|
||||
model.set_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
img_org = cv2.imread(image_path)
|
||||
img = cv2.resize(img_org, (288, 288))
|
||||
x = to_tensor(img)
|
||||
y = to_tensor(img_org)
|
||||
bm = model(x)
|
||||
bm = paddle.nn.functional.interpolate(
|
||||
bm, y.shape[2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
bm_nhwc = bm.transpose([0, 2, 3, 1])
|
||||
out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
|
||||
out_image = to_image(out)
|
||||
cv2.imwrite(output_path, out_image)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="predict")
|
||||
|
||||
parser.add_argument(
|
||||
"--image",
|
||||
"-i",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="",
|
||||
help="The path of image",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"-m",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="",
|
||||
help="The path of model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
"-o",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="",
|
||||
help="The path of output",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(args)
|
||||
|
||||
run(args)
|
||||
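The `(bm_nhwc / 288 - 0.5) * 2` step above rescales the backward map from pixel coordinates on the 288-pixel grid to the `[-1, 1]` range that `grid_sample` expects; the same mapping shown in isolation.

```python
import paddle

bm_pixels = paddle.to_tensor([[0.0, 144.0, 288.0]])  # pixel coordinates on a 288 grid
bm_norm = (bm_pixels / 288 - 0.5) * 2                # grid_sample's normalized range
print(bm_norm.numpy())  # [[-1.  0.  1.]]
```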
@@ -1,7 +0,0 @@
|
||||
hdf5storage
|
||||
loguru
|
||||
numpy
|
||||
scipy
|
||||
opencv-python
|
||||
matplotlib
|
||||
albumentations
|
||||
@@ -1,57 +0,0 @@
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
random.seed(1234567)
|
||||
|
||||
|
||||
def run(args):
|
||||
data_root = os.path.expanduser(args.data_root)
|
||||
ratio = args.train_ratio
|
||||
|
||||
data_path = os.path.join(data_root, "img", "*", "*.png")
|
||||
img_list = glob.glob(data_path, recursive=True)
|
||||
img_list.sort()
|
||||
random.shuffle(img_list)
|
||||
|
||||
train_size = int(len(img_list) * ratio)
|
||||
|
||||
train_text_path = os.path.join(data_root, "train.txt")
|
||||
with open(train_text_path, "w") as file:
|
||||
for item in img_list[:train_size]:
|
||||
parts = Path(item).parts
|
||||
item = os.path.join(parts[-2], parts[-1])
|
||||
file.write("%s\n" % item.split(".png")[0])
|
||||
|
||||
val_text_path = os.path.join(data_root, "val.txt")
|
||||
with open(val_text_path, "w") as file:
|
||||
for item in img_list[train_size:]:
|
||||
parts = Path(item).parts
|
||||
item = os.path.join(parts[-2], parts[-1])
|
||||
file.write("%s\n" % item.split(".png")[0])
|
||||
|
||||
logger.info(f"TRAIN LABEL: {train_text_path}")
|
||||
logger.info(f"VAL LABEL: {val_text_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--data_root",
|
||||
type=str,
|
||||
default="~/datasets/doc3d",
|
||||
help="Data path to load data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_ratio", type=float, default=0.8, help="Ratio of training data"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info(args)
|
||||
|
||||
run(args)
|
||||
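The script writes one identifier per line: the image's subfolder plus its file stem, without the `.png` extension, which is exactly what the dataset loader joins back onto `img/`. A sketch for verifying a generated split, with the data root as an assumption.

```python
# Sketch: check that every id listed in train.txt / val.txt resolves to an image.
import os

data_root = os.path.expanduser("~/datasets/doc3d")  # default used by the script above

for split in ("train", "val"):
    with open(os.path.join(data_root, f"{split}.txt")) as f:
        ids = [line.rstrip() for line in f]
    missing = [i for i in ids if not os.path.exists(os.path.join(data_root, "img", i + ".png"))]
    print(split, len(ids), "entries,", len(missing), "missing")
```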
@@ -1,381 +0,0 @@
|
||||
import argparse
|
||||
import inspect
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.optimizer as optim
|
||||
from loguru import logger
|
||||
from paddle.io import DataLoader
|
||||
from paddle.nn import functional as F
|
||||
from paddle_msssim import ms_ssim, ssim
|
||||
|
||||
from doc3d_dataset import Doc3dDataset
|
||||
from GeoTr import GeoTr
|
||||
from utils import to_image
|
||||
|
||||
FILE = Path(__file__).resolve()
|
||||
ROOT = FILE.parents[0]
|
||||
RANK = int(os.getenv("RANK", -1))
|
||||
|
||||
|
||||
def init_seeds(seed=0, deterministic=False):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
paddle.seed(seed)
|
||||
if deterministic:
|
||||
os.environ["FLAGS_cudnn_deterministic"] = "1"
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
|
||||
|
||||
def colorstr(*input):
|
||||
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code,
|
||||
# i.e. colorstr('blue', 'hello world')
|
||||
|
||||
*args, string = (
|
||||
input if len(input) > 1 else ("blue", "bold", input[0])
|
||||
) # color arguments, string
|
||||
|
||||
colors = {
|
||||
"black": "\033[30m", # basic colors
|
||||
"red": "\033[31m",
|
||||
"green": "\033[32m",
|
||||
"yellow": "\033[33m",
|
||||
"blue": "\033[34m",
|
||||
"magenta": "\033[35m",
|
||||
"cyan": "\033[36m",
|
||||
"white": "\033[37m",
|
||||
"bright_black": "\033[90m", # bright colors
|
||||
"bright_red": "\033[91m",
|
||||
"bright_green": "\033[92m",
|
||||
"bright_yellow": "\033[93m",
|
||||
"bright_blue": "\033[94m",
|
||||
"bright_magenta": "\033[95m",
|
||||
"bright_cyan": "\033[96m",
|
||||
"bright_white": "\033[97m",
|
||||
"end": "\033[0m", # misc
|
||||
"bold": "\033[1m",
|
||||
"underline": "\033[4m",
|
||||
}
|
||||
|
||||
return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
|
||||
|
||||
|
||||
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
|
||||
# Print function arguments (optional args dict)
|
||||
x = inspect.currentframe().f_back # previous frame
|
||||
file, _, func, _, _ = inspect.getframeinfo(x)
|
||||
if args is None: # get args automatically
|
||||
args, _, _, frm = inspect.getargvalues(x)
|
||||
args = {k: v for k, v in frm.items() if k in args}
|
||||
try:
|
||||
file = Path(file).resolve().relative_to(ROOT).with_suffix("")
|
||||
except ValueError:
|
||||
file = Path(file).stem
|
||||
s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
|
||||
logger.info(colorstr(s) + ", ".join(f"{k}={v}" for k, v in args.items()))
|
||||
|
||||
|
||||
def increment_path(path, exist_ok=False, sep="", mkdir=False):
|
||||
# Increment file or directory path,
|
||||
# i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
|
||||
path = Path(path) # os-agnostic
|
||||
if path.exists() and not exist_ok:
|
||||
path, suffix = (
|
||||
(path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
|
||||
)
|
||||
|
||||
for n in range(2, 9999):
|
||||
p = f"{path}{sep}{n}{suffix}" # increment path
|
||||
if not os.path.exists(p):
|
||||
break
|
||||
path = Path(p)
|
||||
|
||||
if mkdir:
|
||||
path.mkdir(parents=True, exist_ok=True) # make directory
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def train(args):
|
||||
save_dir = Path(args.save_dir)
|
||||
|
||||
use_vdl = args.use_vdl
|
||||
if use_vdl:
|
||||
from visualdl import LogWriter
|
||||
|
||||
log_dir = save_dir / "vdl"
|
||||
vdl_writer = LogWriter(str(log_dir))
|
||||
|
||||
# Directories
|
||||
weights_dir = save_dir / "weights"
|
||||
weights_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
last = weights_dir / "last.ckpt"
|
||||
best = weights_dir / "best.ckpt"
|
||||
|
||||
# Hyperparameters
|
||||
|
||||
# Config
|
||||
init_seeds(args.seed)
|
||||
|
||||
# Train loader
|
||||
train_dataset = Doc3dDataset(
|
||||
args.data_root,
|
||||
split="train",
|
||||
is_augment=True,
|
||||
image_size=args.img_size,
|
||||
)
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.workers,
|
||||
)
|
||||
|
||||
# Validation loader
|
||||
val_dataset = Doc3dDataset(
|
||||
args.data_root,
|
||||
split="val",
|
||||
is_augment=False,
|
||||
image_size=args.img_size,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset, batch_size=args.batch_size, num_workers=args.workers
|
||||
)
|
||||
|
||||
# Model
|
||||
model = GeoTr()
|
||||
|
||||
if use_vdl:
|
||||
vdl_writer.add_graph(
|
||||
model,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec([1, 3, args.img_size, args.img_size], "float32")
|
||||
],
|
||||
)
|
||||
|
||||
# Data Parallel Mode
|
||||
if RANK == -1 and paddle.device.cuda.device_count() > 1:
|
||||
model = paddle.DataParallel(model)
|
||||
|
||||
# Scheduler
|
||||
scheduler = optim.lr.OneCycleLR(
|
||||
max_learning_rate=args.lr,
|
||||
total_steps=args.epochs * len(train_loader),
|
||||
phase_pct=0.1,
|
||||
end_learning_rate=args.lr / 2.5e5,
|
||||
)
|
||||
|
||||
# Optimizer
|
||||
optimizer = optim.AdamW(
|
||||
learning_rate=scheduler,
|
||||
parameters=model.parameters(),
|
||||
)
|
||||
|
||||
# loss function
|
||||
l1_loss_fn = nn.L1Loss()
|
||||
mse_loss_fn = nn.MSELoss()
|
||||
|
||||
# Resume
|
||||
best_fitness, start_epoch = 0.0, 0
|
||||
if args.resume:
|
||||
ckpt = paddle.load(args.resume)
|
||||
model.set_state_dict(ckpt["model"])
|
||||
optimizer.set_state_dict(ckpt["optimizer"])
|
||||
scheduler.set_state_dict(ckpt["scheduler"])
|
||||
best_fitness = ckpt["best_fitness"]
|
||||
start_epoch = ckpt["epoch"] + 1
|
||||
|
||||
# Train
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
model.train()
|
||||
|
||||
for i, (img, target) in enumerate(train_loader):
|
||||
img = paddle.to_tensor(img) # NCHW
|
||||
target = paddle.to_tensor(target) # NHWC
|
||||
|
||||
pred = model(img) # NCHW
|
||||
pred_nhwc = pred.transpose([0, 2, 3, 1])
|
||||
|
||||
loss = l1_loss_fn(pred_nhwc, target)
|
||||
mse_loss = mse_loss_fn(pred_nhwc, target)
|
||||
|
||||
if use_vdl:
|
||||
vdl_writer.add_scalar(
|
||||
"Train/L1 Loss", float(loss), epoch * len(train_loader) + i
|
||||
)
|
||||
vdl_writer.add_scalar(
|
||||
"Train/MSE Loss", float(mse_loss), epoch * len(train_loader) + i
|
||||
)
|
||||
vdl_writer.add_scalar(
|
||||
"Train/Learning Rate",
|
||||
float(scheduler.get_lr()),
|
||||
epoch * len(train_loader) + i,
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.clear_grad()
|
||||
|
||||
if i % 10 == 0:
|
||||
logger.info(
|
||||
f"[TRAIN MODE] Epoch: {epoch}, Iter: {i}, L1 Loss: {float(loss)}, "
|
||||
f"MSE Loss: {float(mse_loss)}, LR: {float(scheduler.get_lr())}"
|
||||
)
|
||||
|
||||
# Validation
|
||||
model.eval()
|
||||
|
||||
with paddle.no_grad():
|
||||
avg_ssim = paddle.zeros([])
|
||||
avg_ms_ssim = paddle.zeros([])
|
||||
avg_l1_loss = paddle.zeros([])
|
||||
avg_mse_loss = paddle.zeros([])
|
||||
|
||||
for i, (img, target) in enumerate(val_loader):
|
||||
img = paddle.to_tensor(img)
|
||||
target = paddle.to_tensor(target)
|
||||
|
||||
pred = model(img)
|
||||
pred_nhwc = pred.transpose([0, 2, 3, 1])
|
||||
|
||||
# predict image
|
||||
out = F.grid_sample(img, (pred_nhwc / args.img_size - 0.5) * 2)
|
||||
out_gt = F.grid_sample(img, (target / args.img_size - 0.5) * 2)
|
||||
|
||||
# calculate ssim
|
||||
ssim_val = ssim(out, out_gt, data_range=1.0)
|
||||
ms_ssim_val = ms_ssim(out, out_gt, data_range=1.0)
|
||||
|
||||
loss = l1_loss_fn(pred_nhwc, target)
|
||||
mse_loss = mse_loss_fn(pred_nhwc, target)
|
||||
|
||||
# calculate fitness
|
||||
avg_ssim += ssim_val
|
||||
avg_ms_ssim += ms_ssim_val
|
||||
avg_l1_loss += loss
|
||||
avg_mse_loss += mse_loss
|
||||
|
||||
if i % 10 == 0:
|
||||
logger.info(
|
||||
f"[VAL MODE] Epoch: {epoch}, VAL Iter: {i}, "
|
||||
f"L1 Loss: {float(loss)} MSE Loss: {float(mse_loss)}, "
|
||||
f"MS-SSIM: {float(ms_ssim_val)}, SSIM: {float(ssim_val)}"
|
||||
)
|
||||
|
||||
if use_vdl and i == 0:
|
||||
img_0 = to_image(out[0])
|
||||
img_gt_0 = to_image(out_gt[0])
|
||||
vdl_writer.add_image("Val/Predicted Image No.0", img_0, epoch)
|
||||
vdl_writer.add_image("Val/Target Image No.0", img_gt_0, epoch)
|
||||
|
||||
img_1 = to_image(out[1])
|
||||
img_gt_1 = to_image(out_gt[1])
|
||||
img_gt_1 = img_gt_1.astype("uint8")
|
||||
vdl_writer.add_image("Val/Predicted Image No.1", img_1, epoch)
|
||||
vdl_writer.add_image("Val/Target Image No.1", img_gt_1, epoch)
|
||||
|
||||
img_2 = to_image(out[2])
|
||||
img_gt_2 = to_image(out_gt[2])
|
||||
vdl_writer.add_image("Val/Predicted Image No.2", img_2, epoch)
|
||||
vdl_writer.add_image("Val/Target Image No.2", img_gt_2, epoch)
|
||||
|
||||
avg_ssim /= len(val_loader)
|
||||
avg_ms_ssim /= len(val_loader)
|
||||
avg_l1_loss /= len(val_loader)
|
||||
avg_mse_loss /= len(val_loader)
|
||||
|
||||
if use_vdl:
|
||||
vdl_writer.add_scalar("Val/L1 Loss", float(loss), epoch)
|
||||
vdl_writer.add_scalar("Val/MSE Loss", float(mse_loss), epoch)
|
||||
vdl_writer.add_scalar("Val/SSIM", float(ssim_val), epoch)
|
||||
vdl_writer.add_scalar("Val/MS-SSIM", float(ms_ssim_val), epoch)
|
||||
|
||||
# Save
|
||||
ckpt = {
|
||||
"model": model.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"scheduler": scheduler.state_dict(),
|
||||
"best_fitness": best_fitness,
|
||||
"epoch": epoch,
|
||||
}
|
||||
|
||||
paddle.save(ckpt, str(last))
|
||||
|
||||
if best_fitness < avg_ssim:
|
||||
best_fitness = avg_ssim
|
||||
paddle.save(ckpt, str(best))
|
||||
|
||||
if use_vdl:
|
||||
vdl_writer.close()
|
||||
|
||||
|
||||
def main(args):
|
||||
print_args(vars(args))
|
||||
|
||||
args.save_dir = str(
|
||||
increment_path(Path(args.project) / args.name, exist_ok=args.exist_ok)
|
||||
)
|
||||
|
||||
train(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Hyperparams")
|
||||
parser.add_argument(
|
||||
"--data-root",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="~/datasets/doc3d",
|
||||
help="The root path of the dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--img-size",
|
||||
nargs="?",
|
||||
type=int,
|
||||
default=288,
|
||||
help="The size of the input image",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--epochs",
|
||||
nargs="?",
|
||||
type=int,
|
||||
default=65,
|
||||
help="The number of training epochs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size", nargs="?", type=int, default=12, help="Batch Size"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr", nargs="?", type=float, default=1e-04, help="Learning Rate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to previous saved model to restart from",
|
||||
)
|
||||
parser.add_argument("--workers", type=int, default=8, help="max dataloader workers")
|
||||
parser.add_argument(
|
||||
"--project", default=ROOT / "runs/train", help="save to project/name"
|
||||
)
|
||||
parser.add_argument("--name", default="exp", help="save to project/name")
|
||||
parser.add_argument(
|
||||
"--exist-ok",
|
||||
action="store_true",
|
||||
help="existing project/name ok, do not increment",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="Global training seed")
|
||||
parser.add_argument("--use-vdl", action="store_true", help="use VisualDL as logger")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
@@ -1,10 +0,0 @@
|
||||
export OPENCV_IO_ENABLE_OPENEXR=1
|
||||
export FLAGS_logtostderr=0
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
python train.py --img-size 288 \
|
||||
--name "DocTr++" \
|
||||
--batch-size 12 \
|
||||
--lr 1e-4 \
|
||||
--exist-ok \
|
||||
--use-vdl
|
||||
@@ -1,67 +0,0 @@
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle.nn import functional as F
|
||||
|
||||
|
||||
def to_tensor(img: np.ndarray):
|
||||
"""
|
||||
Converts a numpy array image (HWC) to a Paddle tensor (NCHW).
|
||||
|
||||
Args:
|
||||
img (numpy.ndarray): The input image as a numpy array.
|
||||
|
||||
Returns:
|
||||
out (paddle.Tensor): The output tensor.
|
||||
"""
|
||||
img = img[:, :, ::-1]
|
||||
img = img.astype("float32") / 255.0
|
||||
img = img.transpose(2, 0, 1)
|
||||
out: paddle.Tensor = paddle.to_tensor(img)
|
||||
out = paddle.unsqueeze(out, axis=0)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def to_image(x: paddle.Tensor):
|
||||
"""
|
||||
Converts a Paddle tensor (NCHW) to a numpy array image (HWC).
|
||||
|
||||
Args:
|
||||
x (paddle.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
out (numpy.ndarray): The output image as a numpy array.
|
||||
"""
|
||||
out: np.ndarray = x.squeeze().numpy()
|
||||
out = out.transpose(1, 2, 0)
|
||||
out = out * 255.0
|
||||
out = out.astype("uint8")
|
||||
out = out[:, :, ::-1]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def unwarp(img, bm, bm_data_format="NCHW"):
|
||||
"""
|
||||
Unwarp an image using a flow field.
|
||||
|
||||
Args:
|
||||
img (paddle.Tensor): The input image.
|
||||
bm (paddle.Tensor): The flow field.
|
||||
|
||||
Returns:
|
||||
out (paddle.Tensor): The output image.
|
||||
"""
|
||||
_, _, h, w = img.shape
|
||||
|
||||
if bm_data_format == "NHWC":
|
||||
bm = bm.transpose([0, 3, 1, 2])
|
||||
|
||||
# NCHW
|
||||
bm = F.upsample(bm, size=(h, w), mode="bilinear", align_corners=True)
|
||||
# NHWC
|
||||
bm = bm.transpose([0, 2, 3, 1])
|
||||
# NCHW
|
||||
out = F.grid_sample(img, bm)
|
||||
|
||||
return out
|
||||
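A round-trip sketch for the two helpers above: `to_tensor` turns a BGR `uint8` image into a normalized NCHW tensor, and `to_image` reverses the conversion (up to the uint8 re-quantization).

```python
import numpy as np

from utils import to_image, to_tensor  # same module that predict.py imports from

img = (np.random.rand(64, 64, 3) * 255).astype(np.uint8)  # fake BGR image
t = to_tensor(img)        # [1, 3, 64, 64], float32 values in [0, 1], RGB order
restored = to_image(t)    # back to HWC uint8, BGR order
print(t.shape, restored.shape, restored.dtype)
```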
@@ -1,152 +0,0 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn.initializer as init
|
||||
from scipy import special
|
||||
|
||||
|
||||
def weight_init_(
|
||||
layer, func, weight_name=None, bias_name=None, bias_value=0.0, **kwargs
|
||||
):
|
||||
"""
|
||||
In-place params init function.
|
||||
Usage:
|
||||
.. code-block:: python
|
||||
|
||||
import paddle
|
||||
import numpy as np
|
||||
|
||||
data = np.ones([3, 4], dtype='float32')
|
||||
linear = paddle.nn.Linear(4, 4)
|
||||
input = paddle.to_tensor(data)
|
||||
print(linear.weight)
|
||||
linear(input)
|
||||
|
||||
weight_init_(linear, 'Normal', 'fc_w0', 'fc_b0', std=0.01, mean=0.1)
|
||||
print(linear.weight)
|
||||
"""
|
||||
|
||||
if hasattr(layer, "weight") and layer.weight is not None:
|
||||
getattr(init, func)(**kwargs)(layer.weight)
|
||||
if weight_name is not None:
|
||||
# override weight name
|
||||
layer.weight.name = weight_name
|
||||
|
||||
if hasattr(layer, "bias") and layer.bias is not None:
|
||||
init.Constant(bias_value)(layer.bias)
|
||||
if bias_name is not None:
|
||||
# override bias name
|
||||
layer.bias.name = bias_name
|
||||
|
||||
|
||||
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
||||
def norm_cdf(x):
|
||||
# Computes standard normal cumulative distribution function
|
||||
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
||||
|
||||
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
||||
print(
|
||||
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
||||
"The distribution of values may be incorrect."
|
||||
)
|
||||
|
||||
with paddle.no_grad():
|
||||
# Values are generated by using a truncated uniform distribution and
|
||||
# then using the inverse CDF for the normal distribution.
|
||||
# Get upper and lower cdf values
|
||||
lower = norm_cdf((a - mean) / std)
|
||||
upper = norm_cdf((b - mean) / std)
|
||||
|
||||
# Uniformly fill tensor with values from [l, u], then translate to [2l-1, 2u-1].
|
||||
tmp = np.random.uniform(
|
||||
2 * lower - 1, 2 * upper - 1, size=list(tensor.shape)
|
||||
).astype(np.float32)
|
||||
|
||||
# Use inverse cdf transform for normal distribution to get truncated
|
||||
# standard normal
|
||||
tmp = special.erfinv(tmp)
|
||||
|
||||
# Transform to proper mean, std
|
||||
tmp *= std * math.sqrt(2.0)
|
||||
tmp += mean
|
||||
|
||||
# Clamp to ensure it's in the proper range
|
||||
tmp = np.clip(tmp, a, b)
|
||||
tensor.set_value(paddle.to_tensor(tmp))
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def _calculate_fan_in_and_fan_out(tensor):
|
||||
dimensions = tensor.dim()
|
||||
if dimensions < 2:
|
||||
raise ValueError(
|
||||
"Fan in and fan out can not be computed for tensor "
|
||||
"with fewer than 2 dimensions"
|
||||
)
|
||||
|
||||
num_input_fmaps = tensor.shape[1]
|
||||
num_output_fmaps = tensor.shape[0]
|
||||
receptive_field_size = 1
|
||||
if tensor.dim() > 2:
|
||||
receptive_field_size = tensor[0][0].numel()
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
|
||||
return fan_in, fan_out
|
||||
|
||||
|
||||
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
||||
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
||||
|
||||
|
||||
def kaiming_normal_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
|
||||
def _calculate_correct_fan(tensor, mode):
|
||||
mode = mode.lower()
|
||||
valid_modes = ["fan_in", "fan_out"]
|
||||
if mode not in valid_modes:
|
||||
raise ValueError(
|
||||
"Mode {} not supported, please use one of {}".format(mode, valid_modes)
|
||||
)
|
||||
|
||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
||||
return fan_in if mode == "fan_in" else fan_out
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
|
||||
linear_fns = [
|
||||
"linear",
|
||||
"conv1d",
|
||||
"conv2d",
|
||||
"conv3d",
|
||||
"conv_transpose1d",
|
||||
"conv_transpose2d",
|
||||
"conv_transpose3d",
|
||||
]
|
||||
if nonlinearity in linear_fns or nonlinearity == "sigmoid":
|
||||
return 1
|
||||
elif nonlinearity == "tanh":
|
||||
return 5.0 / 3
|
||||
elif nonlinearity == "relu":
|
||||
return math.sqrt(2.0)
|
||||
elif nonlinearity == "leaky_relu":
|
||||
if param is None:
|
||||
negative_slope = 0.01
|
||||
elif (
|
||||
not isinstance(param, bool)
|
||||
and isinstance(param, int)
|
||||
or isinstance(param, float)
|
||||
):
|
||||
negative_slope = param
|
||||
else:
|
||||
raise ValueError("negative_slope {} not a valid number".format(param))
|
||||
return math.sqrt(2.0 / (1 + negative_slope**2))
|
||||
else:
|
||||
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
||||
|
||||
fan = _calculate_correct_fan(tensor, mode)
|
||||
gain = calculate_gain(nonlinearity, a)
|
||||
std = gain / math.sqrt(fan)
|
||||
with paddle.no_grad():
|
||||
paddle.nn.initializer.Normal(0, std)(tensor)
|
||||
return tensor
|
||||
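A short sketch applying the helpers above to freshly created layers, following the same calls the encoder makes; the flat module name `weight_init` is an assumption (the encoder imports it relatively).

```python
import paddle

from weight_init import kaiming_normal_, trunc_normal_, weight_init_

conv = paddle.nn.Conv2D(3, 16, 3)
bn = paddle.nn.BatchNorm2D(16)
linear = paddle.nn.Linear(64, 10)

# He/Kaiming init on the conv weight, as the encoder does for its Conv2D layers.
kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")

# Norm layers: unit scale and zero bias via the generic helper.
weight_init_(bn, "Constant", value=1, bias_value=0.0)

# Truncated normal init, clipped to [a, b] around the mean.
trunc_normal_(linear.weight, std=0.02)
```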
26
docker-compose.dev.yml
Normal file
@@ -0,0 +1,26 @@
|
||||
services:
|
||||
fcb_ai_dev:
|
||||
image: fcb_ai_dev:0.0.10
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile.dev
|
||||
# Container name; customize as needed
|
||||
container_name: fcb_ai_dev
|
||||
hostname: fcb_ai_dev
|
||||
# Always restart the container
|
||||
restart: always
|
||||
# Port mapping; adjust the host port as needed
|
||||
ports:
|
||||
- "8022:22"
|
||||
# Volume mapping; adjust to the actual paths
|
||||
volumes:
|
||||
- ./log:/app/log
|
||||
- ./model:/app/model
|
||||
# Enable GPU support
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- device_ids: [ '0', '1' ]
|
||||
capabilities: [ 'gpu' ]
|
||||
driver: 'nvidia'
|
||||
@@ -1,6 +1,6 @@
|
||||
x-env:
|
||||
&template
|
||||
image: fcb_photo_review:1.14.6
|
||||
image: fcb_photo_review:1.15.7
|
||||
restart: always
|
||||
|
||||
x-review:
|
||||
@@ -31,30 +31,12 @@ x-mask:
|
||||
driver: 'nvidia'
|
||||
|
||||
services:
|
||||
det_api:
|
||||
<<: *template
|
||||
build:
|
||||
context: .
|
||||
container_name: det_api
|
||||
hostname: det_api
|
||||
volumes:
|
||||
- ./log:/app/log
|
||||
- ./model:/app/model
|
||||
# command: [ 'det_api.py' ]
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- device_ids: [ '0' ]
|
||||
capabilities: [ 'gpu' ]
|
||||
driver: 'nvidia'
|
||||
|
||||
photo_review_1:
|
||||
<<: *review_template
|
||||
build:
|
||||
context: .
|
||||
container_name: photo_review_1
|
||||
hostname: photo_review_1
|
||||
depends_on:
|
||||
- det_api
|
||||
command: [ 'photo_review.py', '--clean', 'True' ]
|
||||
|
||||
photo_review_2:
|
||||
|
||||
BIN
document/Linux下搭建和部署Paddle相关项目.docx
Normal file
Binary file not shown.
BIN
document/OCR工作效率统计.xlsx
Normal file
Binary file not shown.
153
document/PaddleOCR命令.md
Normal file
@@ -0,0 +1,153 @@
|
||||
# PaddleOCR
|
||||
|
||||
------
|
||||
|
||||
## Dataset
|
||||
|
||||
Everything in this section is done inside the PPOCRLabel directory.
|
||||
|
||||
```bash
|
||||
# Enter the PPOCRLabel directory
|
||||
cd .\PPOCRLabel\
|
||||
```
|
||||
|
||||
### Annotation
|
||||
|
||||
You can launch PPOCRLabel.py directly with Run in PyCharm, but the UI defaults to English.
|
||||
|
||||
```bash
|
||||
# Run the labeling tool with the Chinese UI
|
||||
python PPOCRLabel.py --lang ch
|
||||
# Labeling with key information extraction (KIE) enabled
|
||||
python PPOCRLabel.py --lang ch --kie True
|
||||
```
|
||||
|
||||
### Splitting the dataset
|
||||
|
||||
```bash
|
||||
python gen_ocr_train_val_test.py --trainValTestRatio 6:2:2 --datasetRootPath ../train_data/drivingData
|
||||
```
|
||||
|
||||
------
|
||||
|
||||
## Detection model
|
||||
|
||||
Return to the project root directory first.
|
||||
|
||||
### Training
|
||||
|
||||
```bash
|
||||
python tools/train.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml
|
||||
```
|
||||
|
||||
### Testing
|
||||
|
||||
```bash
|
||||
python tools/infer_det.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml -o Global.pretrained_model=output/det_v4_bankcard/best_accuracy.pdparams Global.infer_img=train_data/drivingData/1.jpg
|
||||
```
|
||||
|
||||
### Resuming training
|
||||
|
||||
```bash
|
||||
python tools/train.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml -o Global.checkpoints=./output/det_v4_bankcard/latest
|
||||
```
|
||||
|
||||
------
|
||||
|
||||
## Recognition model
|
||||
|
||||
### Training
|
||||
|
||||
```bash
|
||||
python tools/train.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_ampO2_ultra.yml
|
||||
```
|
||||
|
||||
### Testing
|
||||
|
||||
```bash
|
||||
python tools/infer_rec.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_ampO2_ultra.yml -o Global.pretrained_model=output/rec_v4_bankcard/best_accuracy.pdparams Global.infer_img=train_data/drivingData/crop_img/1_crop_0.jpg
|
||||
```
|
||||
|
||||
------
|
||||
|
||||
## Inference model
|
||||
|
||||
### Converting the detection model
|
||||
|
||||
```bash
|
||||
python tools/export_model.py -c configs/det/ch_PP-OCRv4/ch_PP-OCRv4_det_student.yml -o Global.pretrained_model=output/det_v4_bankcard/best_accuracy.pdparams
|
||||
```
|
||||
|
||||
### Converting the recognition model
|
||||
|
||||
```bash
|
||||
python tools/export_model.py -c configs/rec/PP-OCRv4/ch_PP-OCRv4_rec_ampO2_ultra.yml -o Global.pretrained_model=output/rec_v4_bankcard/best_accuracy.pdparams
|
||||
```
|
||||
|
||||
### Detection and recognition test
|
||||
|
||||
```bash
|
||||
python tools/infer/predict_system.py --det_model_dir=inference_model/det_v4_bankcard --rec_model_dir=inference_model/rec_v4_bankcard --rec_char_dict_path=ppocr/utils/num_dict.txt --image_dir=train_data/drivingData/1.jpg
|
||||
```
|
||||
|
||||
------
|
||||
|
||||
## Mobile model
|
||||
|
||||
### Converting the detection model
|
||||
|
||||
```bash
|
||||
paddle_lite_opt --model_file=inference_model/det_v4_bankcard/inference.pdmodel --param_file=inference_model/det_v4_bankcard/inference.pdiparams --optimize_out=inference_model/det_v4_nb_bankcard --valid_targets=arm --optimize_out_type=naive_buffer
|
||||
```
|
||||
|
||||
### Converting the recognition model
|
||||
|
||||
```bash
|
||||
paddle_lite_opt --model_file=inference_model/rec_v4_bankcard/inference.pdmodel --param_file=inference_model/rec_v4_bankcard/inference.pdiparams --optimize_out=inference_model/rec_v4_nb_bankcard --valid_targets=arm --optimize_out_type=naive_buffer
|
||||
```
|
||||
|
||||
------
|
||||
|
||||
------
|
||||
|
||||
# PaddleNLP
|
||||
|
||||
## Dataset
|
||||
|
||||
Data annotation is done with Label Studio; the installation steps are omitted here.
|
||||
|
||||
```bash
|
||||
# Open the Anaconda Prompt
|
||||
# Activate the environment where Label Studio is installed
|
||||
conda activate label-studio
|
||||
# Start Label Studio
|
||||
label-studio start
|
||||
```
|
||||
|
||||
[Annotation workflow](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/applications/information_extraction/label_studio_doc.md)
|
||||
|
||||
### Data conversion
|
||||
|
||||
```bash
|
||||
# Run after entering PaddleNLP\applications\information_extraction
|
||||
python label_studio.py --label_studio_file ./document/data/label_studio.json --save_dir ./document/data --splits 0.8 0.1 0.1 --task_type ext
|
||||
```
|
||||
|
||||
|
||||
|
||||
------
|
||||
|
||||
## Training the model
|
||||
|
||||
```bash
|
||||
# Run inside PaddleNLP\applications\information_extraction\document (dual-GPU training)
|
||||
python -u -m paddle.distributed.launch --gpus "0,1" finetune.py --device gpu --logging_steps 5 --save_steps 25 --eval_steps 25 --seed 42 --model_name_or_path uie-x-base --output_dir ./checkpoint/model_best --train_path data/train.txt --dev_path data/dev.txt --max_seq_len 512 --per_device_train_batch_size 8 --per_device_eval_batch_size 8 --num_train_epochs 10 --learning_rate 1e-5 --do_train --do_eval --do_export --export_model_dir ./checkpoint/model_best --overwrite_output_dir --disable_tqdm False --metric_for_best_model eval_f1 --load_best_model_at_end True --save_total_limit 1
|
||||
```
|
||||
|
||||
------
|
||||
|
||||
References:
|
||||
|
||||
[PaddleOCR训练属于自己的模型详细教程](https://blog.csdn.net/qq_52852432/article/details/131817619?utm_medium=distribute.pc_relevant.none-task-blog-2~default~baidujs_baidulandingword~default-0-131817619-blog-124628731.235^v40^pc_relevant_3m_sort_dl_base1&spm=1001.2101.3001.4242.1&utm_relevant_index=3)
|
||||
[端侧部署](https://github.com/PaddlePaddle/PaddleOCR/blob/release/2.7/deploy/lite/readme_ch.md)
|
||||
[PaddleNLP关键信息抽取](https://blog.csdn.net/z5z5z5z56/article/details/130346646)
|
||||
329
document/paddle镜像自带依赖.md
Normal file
@@ -0,0 +1,329 @@
|
||||
anyio 4.0.0
|
||||
astor 0.8.1
|
||||
certifi 2019.11.28
|
||||
chardet 3.0.4
|
||||
dbus-python 1.2.16
|
||||
decorator 5.1.1
|
||||
distro-info 0.23+ubuntu1.1
|
||||
exceptiongroup 1.1.3
|
||||
h11 0.14.0
|
||||
httpcore 1.0.2
|
||||
httpx 0.25.1
|
||||
idna 2.8
|
||||
numpy 1.26.2
|
||||
opt-einsum 3.3.0
|
||||
paddlepaddle-gpu 2.6.1.post120
|
||||
Pillow 10.1.0
|
||||
pip 24.0
|
||||
protobuf 4.25.0
|
||||
PyGObject 3.36.0
|
||||
python-apt 2.0.1+ubuntu0.20.4.1
|
||||
requests 2.22.0
|
||||
requests-unixsocket 0.2.0
|
||||
setuptools 68.2.2
|
||||
six 1.14.0
|
||||
sniffio 1.3.0
|
||||
unattended-upgrades 0.1
|
||||
urllib3 1.25.8
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlepaddle/paddle:3.1.0-gpu-cuda12.9-cudnn9.9:
|
||||
|
||||
python:3.10.12
|
||||
|
||||
Package Version
|
||||
------------------------ ----------
|
||||
anyio 4.9.0
|
||||
certifi 2025.6.15
|
||||
decorator 5.2.1
|
||||
exceptiongroup 1.3.0
|
||||
h11 0.16.0
|
||||
httpcore 1.0.9
|
||||
httpx 0.28.1
|
||||
idna 3.10
|
||||
networkx 3.4.2
|
||||
numpy 2.2.6
|
||||
nvidia-cublas-cu12 12.9.0.13
|
||||
nvidia-cuda-cccl-cu12 12.9.27
|
||||
nvidia-cuda-cupti-cu12 12.9.19
|
||||
nvidia-cuda-nvrtc-cu12 12.9.41
|
||||
nvidia-cuda-runtime-cu12 12.9.37
|
||||
nvidia-cudnn-cu12 9.9.0.52
|
||||
nvidia-cufft-cu12 11.4.0.6
|
||||
nvidia-cufile-cu12 1.14.0.30
|
||||
nvidia-curand-cu12 10.3.10.19
|
||||
nvidia-cusolver-cu12 11.7.4.40
|
||||
nvidia-cusparse-cu12 12.5.9.5
|
||||
nvidia-cusparselt-cu12 0.7.1
|
||||
nvidia-nccl-cu12 2.26.5
|
||||
nvidia-nvjitlink-cu12 12.9.41
|
||||
nvidia-nvtx-cu12 12.9.19
|
||||
opt-einsum 3.3.0
|
||||
paddlepaddle-gpu 3.1.0
|
||||
pillow 11.2.1
|
||||
pip 25.1.1
|
||||
protobuf 6.31.1
|
||||
setuptools 59.6.0
|
||||
sniffio 1.3.1
|
||||
typing_extensions 4.14.0
|
||||
wheel 0.37.1
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
ccr-2vdh3abv-pub.cnc.bj.baidubce.com/paddlex/paddlex:paddlex3.1.2-paddlepaddle3.0.0-gpu-cuda12.6-cudnn9.5-trt10.5
|
||||
|
||||
python:3.10.18
|
||||
|
||||
Package Version Editable project location
|
||||
------------------------- -------------------- -------------------------
|
||||
aiohappyeyeballs 2.6.1
|
||||
aiohttp 3.12.13
|
||||
aiosignal 1.4.0
|
||||
aistudio_sdk 0.3.5
|
||||
albucore 0.0.13+pdx
|
||||
albumentations 1.4.10+pdx
|
||||
alembic 1.16.2
|
||||
annotated-types 0.7.0
|
||||
anyio 4.9.0
|
||||
astor 0.8.1
|
||||
asttokens 3.0.0
|
||||
async-timeout 4.0.3
|
||||
attrdict3 2.0.2
|
||||
attrs 25.3.0
|
||||
babel 2.17.0
|
||||
bce-python-sdk 0.9.35
|
||||
beautifulsoup4 4.13.4
|
||||
blinker 1.9.0
|
||||
cachetools 6.1.0
|
||||
certifi 2019.11.28
|
||||
cffi 1.17.1
|
||||
chardet 3.0.4
|
||||
charset-normalizer 3.4.2
|
||||
chinese-calendar 1.8.0
|
||||
click 8.2.1
|
||||
cloudpickle 3.1.1
|
||||
colorama 0.4.6
|
||||
colorlog 6.9.0
|
||||
ConfigSpace 1.2.1
|
||||
contourpy 1.3.2
|
||||
cssselect 1.3.0
|
||||
cssutils 2.11.1
|
||||
cycler 0.12.1
|
||||
Cython 3.1.2
|
||||
dataclasses-json 0.6.7
|
||||
datasets 3.6.0
|
||||
dbus-python 1.2.16
|
||||
decorator 5.2.1
|
||||
decord 0.6.0
|
||||
descartes 1.1.0
|
||||
dill 0.3.4
|
||||
distro 1.9.0
|
||||
distro-info 0.23+ubuntu1.1
|
||||
easydict 1.13
|
||||
einops 0.8.1
|
||||
et_xmlfile 2.0.0
|
||||
exceptiongroup 1.2.2
|
||||
executing 2.2.0
|
||||
faiss-cpu 1.8.0.post1
|
||||
fastapi 0.116.0
|
||||
filelock 3.18.0
|
||||
fire 0.7.0
|
||||
FLAML 2.3.5
|
||||
Flask 3.1.1
|
||||
flask-babel 4.0.0
|
||||
fonttools 4.58.5
|
||||
frozenlist 1.7.0
|
||||
fsspec 2025.3.0
|
||||
ftfy 6.3.1
|
||||
future 1.0.0
|
||||
gast 0.3.3
|
||||
GPUtil 1.4.0
|
||||
greenlet 3.2.3
|
||||
h11 0.14.0
|
||||
h5py 3.14.0
|
||||
hf-xet 1.1.5
|
||||
hpbandster 0.7.4
|
||||
httpcore 1.0.7
|
||||
httpx 0.28.1
|
||||
httpx-sse 0.4.1
|
||||
huggingface-hub 0.33.2
|
||||
idna 2.8
|
||||
imageio 2.37.0
|
||||
imagesize 1.4.1
|
||||
imgaug 0.4.0+pdx
|
||||
ipython 8.37.0
|
||||
itsdangerous 2.2.0
|
||||
jedi 0.19.2
|
||||
jieba 0.42.1
|
||||
Jinja2 3.1.6
|
||||
jiter 0.10.0
|
||||
joblib 1.5.1
|
||||
jsonpatch 1.33
|
||||
jsonpointer 3.0.0
|
||||
jsonschema 4.24.0
|
||||
jsonschema-specifications 2025.4.1
|
||||
kiwisolver 1.4.8
|
||||
langchain 0.3.26
|
||||
langchain-community 0.3.27
|
||||
langchain-core 0.3.68
|
||||
langchain-openai 0.3.27
|
||||
langchain-text-splitters 0.3.8
|
||||
langsmith 0.4.4
|
||||
lapx 0.5.11.post1
|
||||
lazy_loader 0.4
|
||||
llvmlite 0.44.0
|
||||
lmdb 1.6.2
|
||||
lxml 6.0.0
|
||||
Mako 1.3.10
|
||||
markdown-it-py 3.0.0
|
||||
MarkupSafe 3.0.2
|
||||
marshmallow 3.26.1
|
||||
matplotlib 3.5.3
|
||||
matplotlib-inline 0.1.7
|
||||
mdurl 0.1.2
|
||||
more-itertools 10.7.0
|
||||
motmetrics 1.4.0
|
||||
msgpack 1.1.1
|
||||
multidict 6.6.3
|
||||
multiprocess 0.70.12.2
|
||||
mypy_extensions 1.1.0
|
||||
netifaces 0.11.0
|
||||
networkx 3.4.2
|
||||
numba 0.61.2
|
||||
numpy 1.24.4
|
||||
nuscenes-devkit 1.1.11+pdx
|
||||
onnx 1.17.0
|
||||
onnxoptimizer 0.3.13
|
||||
openai 1.93.1
|
||||
opencv-contrib-python 4.10.0.84
|
||||
openpyxl 3.1.5
|
||||
opt-einsum 3.3.0
|
||||
optuna 4.4.0
|
||||
orjson 3.10.18
|
||||
packaging 24.2
|
||||
paddle2onnx 2.0.2rc3
|
||||
paddle3d 0.0.0
|
||||
paddleclas 2.6.0
|
||||
paddledet 0.0.0
|
||||
paddlefsl 1.1.0
|
||||
paddlenlp 2.8.0.post0
|
||||
paddlepaddle-gpu 3.0.0
|
||||
paddleseg 0.0.0.dev0
|
||||
paddlets 1.1.0
|
||||
paddlex 3.1.2 /root/PaddleX
|
||||
pandas 1.3.5
|
||||
parso 0.8.4
|
||||
patsy 1.0.1
|
||||
pexpect 4.9.0
|
||||
pillow 11.1.0
|
||||
pip 25.1.1
|
||||
polygraphy 0.49.24
|
||||
ppvideo 2.3.0
|
||||
premailer 3.10.0
|
||||
prettytable 3.16.0
|
||||
prompt_toolkit 3.0.51
|
||||
propcache 0.3.2
|
||||
protobuf 6.30.1
|
||||
psutil 7.0.0
|
||||
ptyprocess 0.7.0
|
||||
pure_eval 0.2.3
|
||||
py-cpuinfo 9.0.0
|
||||
pyarrow 20.0.0
|
||||
pybind11 2.13.6
|
||||
pybind11-stubgen 2.5.1
|
||||
pyclipper 1.3.0.post6
|
||||
pycocotools 2.0.8
|
||||
pycparser 2.22
|
||||
pycryptodome 3.23.0
|
||||
pydantic 2.11.7
|
||||
pydantic_core 2.33.2
|
||||
pydantic-settings 2.10.1
|
||||
Pygments 2.19.2
|
||||
PyGObject 3.36.0
|
||||
PyMatting 1.1.14
|
||||
pyod 2.0.5
|
||||
pypandoc 1.15
|
||||
pyparsing 3.2.3
|
||||
pypdfium2 4.30.1
|
||||
pyquaternion 0.9.9
|
||||
Pyro4 4.82
|
||||
python-apt 2.0.1+ubuntu0.20.4.1
|
||||
python-dateutil 2.9.0.post0
|
||||
python-docx 1.2.0
|
||||
python-dotenv 1.1.1
|
||||
pytz 2025.2
|
||||
PyWavelets 1.3.0
|
||||
PyYAML 6.0.2
|
||||
RapidFuzz 3.13.0
|
||||
rarfile 4.2
|
||||
ray 2.47.1
|
||||
referencing 0.36.2
|
||||
regex 2024.11.6
|
||||
requests 2.32.4
|
||||
requests-toolbelt 1.0.0
|
||||
requests-unixsocket 0.2.0
|
||||
rich 14.0.0
|
||||
rpds-py 0.26.0
|
||||
ruamel.yaml 0.18.14
|
||||
ruamel.yaml.clib 0.2.12
|
||||
safetensors 0.5.3
|
||||
scikit-image 0.25.2
|
||||
scikit-learn 1.3.2
|
||||
scipy 1.15.3
|
||||
seaborn 0.13.2
|
||||
sentencepiece 0.2.0
|
||||
seqeval 1.2.2
|
||||
serpent 1.41
|
||||
setuptools 68.2.2
|
||||
shap 0.48.0
|
||||
Shapely 1.8.5.post1
|
||||
shellingham 1.5.4
|
||||
six 1.14.0
|
||||
sklearn 0.0
|
||||
slicer 0.0.8
|
||||
sniffio 1.3.1
|
||||
soundfile 0.13.1
|
||||
soupsieve 2.7
|
||||
SQLAlchemy 2.0.41
|
||||
stack-data 0.6.3
|
||||
starlette 0.46.2
|
||||
statsmodels 0.14.1
|
||||
tenacity 9.1.2
|
||||
tensorboardX 2.6.4
|
||||
tensorrt 10.5.0
|
||||
termcolor 3.1.0
|
||||
terminaltables 3.1.10
|
||||
threadpoolctl 3.6.0
|
||||
tifffile 2025.5.10
|
||||
tiktoken 0.9.0
|
||||
tokenizers 0.19.1
|
||||
tomli 2.2.1
|
||||
tool_helpers 0.1.2
|
||||
tqdm 4.67.1
|
||||
traitlets 5.14.3
|
||||
typeguard 4.4.4
|
||||
typer 0.16.0
|
||||
typing_extensions 4.14.1
|
||||
typing-inspect 0.9.0
|
||||
typing-inspection 0.4.1
|
||||
tzdata 2025.2
|
||||
ujson 5.10.0
|
||||
unattended-upgrades 0.1
|
||||
urllib3 1.25.8
|
||||
uvicorn 0.35.0
|
||||
visualdl 2.5.3
|
||||
Wand 0.6.13
|
||||
wcwidth 0.2.13
|
||||
Werkzeug 3.1.3
|
||||
xmltodict 0.14.2
|
||||
xxhash 3.5.0
|
||||
yacs 0.1.8
|
||||
yarl 1.20.1
|
||||
zstandard 0.23.0
|
||||
BIN
document/关于使用PaddleOCR训练模型的进展情况说明.docx
Normal file
Binary file not shown.
BIN
document/医保类型识别结果.xlsx
Normal file
Binary file not shown.
@@ -8,6 +8,7 @@ LOG_PATHS = [
|
||||
f"log/{HOSTNAME}/ucloud",
|
||||
f"log/{HOSTNAME}/error",
|
||||
f"log/{HOSTNAME}/qr",
|
||||
f"log/{HOSTNAME}/sql",
|
||||
]
|
||||
for path in LOG_PATHS:
|
||||
if not os.path.exists(path):
|
||||
@@ -74,6 +75,16 @@ LOGGING_CONFIG = {
|
||||
'backupCount': 14,
|
||||
'encoding': 'utf-8',
|
||||
},
|
||||
'sql': {
|
||||
'class': 'logging.handlers.TimedRotatingFileHandler',
|
||||
'level': 'INFO',
|
||||
'formatter': 'standard',
|
||||
'filename': f'log/{HOSTNAME}/sql/fcb_photo_review_sql.log',
|
||||
'when': 'midnight',
|
||||
'interval': 1,
|
||||
'backupCount': 14,
|
||||
'encoding': 'utf-8',
|
||||
},
|
||||
},
|
||||
|
||||
# loggers定义了日志记录器
|
||||
@@ -98,5 +109,10 @@ LOGGING_CONFIG = {
|
||||
'level': 'DEBUG',
|
||||
'propagate': False,
|
||||
},
|
||||
'sql': {
|
||||
'handlers': ['console', 'sql'],
|
||||
'level': 'DEBUG',
|
||||
'propagate': False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
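With the handler and logger added above, SQL activity gets its own midnight-rotated file; a minimal sketch of how a caller would use it once `dictConfig` has been applied at startup (the message text is illustrative).

```python
import logging

# Assumes logging.config.dictConfig(LOGGING_CONFIG) has already run at startup.
sql_logger = logging.getLogger("sql")
sql_logger.info("executed query against ZxPhhd")  # goes to the console and the rotating sql log
```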
@@ -62,7 +62,7 @@ def find_boxes(content, layout, offset=0, length=None, improve=False, image_path
|
||||
captured_image, offset_x, offset_y = image_util.expand_to_a4_size(captured_image)
|
||||
cv2.imwrite(temp_file.name, captured_image)
|
||||
try:
|
||||
layouts = util.get_ocr_layout(OCR, temp_file.name)
|
||||
layouts, _ = util.get_ocr_layout(OCR, temp_file.name)
|
||||
except TypeError:
|
||||
# 如果是类型错误,大概率是没识别到文字
|
||||
layouts = []
|
||||
@@ -100,7 +100,7 @@ def get_mask_layout(image, name, id_card_num):
|
||||
result = []
|
||||
try:
|
||||
try:
|
||||
layouts = util.get_ocr_layout(OCR, temp_file.name)
|
||||
layouts, _ = util.get_ocr_layout(OCR, temp_file.name)
|
||||
# layouts = OCR.parse({"doc": temp_file.name})["layout"]
|
||||
except TypeError:
|
||||
# 如果是类型错误,大概率是没识别到文字
|
||||
@@ -160,6 +160,8 @@ def get_mask_layout(image, name, id_card_num):
|
||||
result += _find_boxes_by_keys(ID_CARD_NUM_KEYS)
|
||||
|
||||
return result
|
||||
except MemoryError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.error("涂抹时出错!", exc_info=e)
|
||||
return result
|
||||
@@ -196,7 +198,9 @@ def mask_photo(img_url, name, id_card_num, color=(255, 255, 255)):
|
||||
return do_mask, i
|
||||
|
||||
# 打开图片
|
||||
image = image_util.read(img_url)
|
||||
image, _ = image_util.read(img_url)
|
||||
if image is None:
|
||||
return False, image
|
||||
original_image = image
|
||||
is_masked, image = _mask(image, name, id_card_num, color)
|
||||
if not is_masked:
|
||||
|
||||
@@ -23,7 +23,7 @@ def check_error(error_ocr):
|
||||
|
||||
image = mask_photo(img_url, name, id_card_num, (0, 0, 0))[1]
|
||||
final_img_url = ufile.get_private_url(error_ocr.cfjaddress, "drg100")
|
||||
final_image = image_util.read(final_img_url)
|
||||
final_image, _ = image_util.read(final_img_url)
|
||||
return image_util.combined(final_image, image)
|
||||
|
||||
|
||||
|
||||
@@ -13,14 +13,14 @@ from photo_review import auto_photo_review, SEND_ERROR_EMAIL
|
||||
|
||||
# 项目必须从此处启动,否则代码中的相对路径可能导致错误的发生
|
||||
if __name__ == '__main__':
|
||||
program_name = '照片审核自动识别脚本'
|
||||
program_name = "照片审核自动识别脚本"
|
||||
logging.config.dictConfig(LOGGING_CONFIG)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--clean", default=False, type=bool, help="是否将识别中的案子改为待识别状态")
|
||||
args = parser.parse_args()
|
||||
if args.clean:
|
||||
# 主要用于启动时,清除仍在涂抹中的案子
|
||||
# 主要用于启动时,清除仍在识别中的案子
|
||||
session = MysqlSession()
|
||||
update_flag = (update(ZxPhhd).where(ZxPhhd.exsuccess_flag == "2").values(exsuccess_flag="1"))
|
||||
session.execute(update_flag)
|
||||
@@ -34,7 +34,7 @@ if __name__ == '__main__':
|
||||
logging.info(f"【{program_name}】开始运行")
|
||||
auto_photo_review.main()
|
||||
except Exception as e:
|
||||
error_logger = logging.getLogger('error')
|
||||
error_logger = logging.getLogger("error")
|
||||
error_logger.error(traceback.format_exc())
|
||||
if SEND_ERROR_EMAIL:
|
||||
send_error_email(program_name, repr(e), traceback.format_exc())
|
||||
|
||||
@@ -2,9 +2,9 @@ import jieba
|
||||
from paddlenlp import Taskflow
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
'''
|
||||
"""
|
||||
项目配置
|
||||
'''
|
||||
"""
|
||||
# 每次从数据库获取的案子数量
|
||||
PHHD_BATCH_SIZE = 10
|
||||
# 没有查询到案子的等待时间(分钟)
|
||||
@@ -18,35 +18,35 @@ LAYOUT_ANALYSIS = False
|
||||
信息抽取关键词配置
|
||||
"""
|
||||
# 患者姓名
|
||||
PATIENT_NAME = ['患者姓名']
|
||||
PATIENT_NAME = ["患者姓名"]
|
||||
# 入院日期
|
||||
ADMISSION_DATE = ['入院日期']
|
||||
ADMISSION_DATE = ["入院日期"]
|
||||
# 出院日期
|
||||
DISCHARGE_DATE = ['出院日期']
|
||||
DISCHARGE_DATE = ["出院日期"]
|
||||
# 发生医疗费
|
||||
MEDICAL_EXPENSES = ['费用总额']
|
||||
MEDICAL_EXPENSES = ["费用总额"]
|
||||
# 个人现金支付
|
||||
PERSONAL_CASH_PAYMENT = ['个人现金支付']
|
||||
PERSONAL_CASH_PAYMENT = ["个人现金支付"]
|
||||
# 个人账户支付
|
||||
PERSONAL_ACCOUNT_PAYMENT = ['个人账户支付']
|
||||
PERSONAL_ACCOUNT_PAYMENT = ["个人账户支付"]
|
||||
# 个人自费金额
|
||||
PERSONAL_FUNDED_AMOUNT = ['自费金额', '个人自费']
|
||||
PERSONAL_FUNDED_AMOUNT = ["自费金额", "个人自费"]
|
||||
# 医保类别
|
||||
MEDICAL_INSURANCE_TYPE = ['医保类型']
|
||||
MEDICAL_INSURANCE_TYPE = ["医保类型"]
|
||||
# 就诊医院
|
||||
HOSPITAL = ['医院']
|
||||
HOSPITAL = ["医院"]
|
||||
# 就诊科室
|
||||
DEPARTMENT = ['科室']
|
||||
DEPARTMENT = ["科室"]
|
||||
# 主治医生
|
||||
DOCTOR = ['主治医生']
|
||||
DOCTOR = ["主治医生"]
|
||||
# 住院号
|
||||
ADMISSION_ID = ['住院号']
|
||||
ADMISSION_ID = ["住院号"]
|
||||
# 医保结算单号码
|
||||
SETTLEMENT_ID = ['医保结算单号码']
|
||||
SETTLEMENT_ID = ["医保结算单号码"]
|
||||
# 年龄
|
||||
AGE = ['年龄']
|
||||
AGE = ["年龄"]
|
||||
# 大写总额
|
||||
UPPERCASE_MEDICAL_EXPENSES = ['大写总额']
|
||||
UPPERCASE_MEDICAL_EXPENSES = ["大写总额"]
|
||||
|
||||
SETTLEMENT_LIST_SCHEMA = \
|
||||
(PATIENT_NAME + ADMISSION_DATE + DISCHARGE_DATE + MEDICAL_EXPENSES + PERSONAL_CASH_PAYMENT
|
||||
@@ -58,47 +58,55 @@ DISCHARGE_RECORD_SCHEMA = \
|
||||
|
||||
COST_LIST_SCHEMA = PATIENT_NAME + ADMISSION_DATE + DISCHARGE_DATE + MEDICAL_EXPENSES
|
||||
|
||||
'''
|
||||
"""
|
||||
别名配置
|
||||
'''
|
||||
"""
|
||||
# 使用别名中的value替换key。考虑到效率问题,只会替换第一个匹配到的key。
|
||||
HOSPITAL_ALIAS = {
|
||||
'沐阳': ['沭阳'],
|
||||
'连水': ['涟水'],
|
||||
'唯宁': ['睢宁'], # 雕宁
|
||||
'九〇四': ['904'],
|
||||
'漂水': ['溧水'],
|
||||
"沐阳": ["沭阳"],
|
||||
"连水": ["涟水"],
|
||||
"唯宁": ["睢宁"], # 雕宁
|
||||
"九〇四": ["904"],
|
||||
"漂水": ["溧水"],
|
||||
}
|
||||
DEPARTMENT_ALIAS = {
|
||||
'耳鼻喉': ['耳鼻咽喉'],
|
||||
'急症': ['急诊'],
|
||||
"耳鼻喉": ["耳鼻咽喉"],
|
||||
"急症": ["急诊"],
|
||||
}
|
||||
|
||||
'''
|
||||
"""
|
||||
搜索过滤配置
|
||||
'''
|
||||
"""
|
||||
# 默认会过滤单字
|
||||
HOSPITAL_FILTER = ['医院', '人民', '第一', '第二', '第三', '大学', '附属']
|
||||
HOSPITAL_FILTER = ["医院", "人民", "第一", "第二", "第三", "大学", "附属"]
|
||||
|
||||
DEPARTMENT_FILTER = ['医', '伤', '西', '新']
|
||||
DEPARTMENT_FILTER = ["医", "伤", "西", "新"]
|
||||
|
||||
'''
|
||||
"""
|
||||
分词配置
|
||||
'''
|
||||
jieba.suggest_freq(('肿瘤', '医院'), True)
|
||||
jieba.suggest_freq(('骨', '伤'), True)
|
||||
jieba.suggest_freq(('感染', '性'), True)
|
||||
jieba.suggest_freq(('胆', '道'), True)
|
||||
jieba.suggest_freq(('脾', '胃'), True)
|
||||
"""
|
||||
jieba.suggest_freq(("肿瘤", "医院"), True)
|
||||
jieba.suggest_freq(("骨", "伤"), True)
|
||||
jieba.suggest_freq(("感染", "性"), True)
|
||||
jieba.suggest_freq(("胆", "道"), True)
|
||||
jieba.suggest_freq(("脾", "胃"), True)
|
||||
|
||||
'''
|
||||
"""
|
||||
模型配置
|
||||
'''
|
||||
SETTLEMENT_IE = Taskflow('information_extraction', schema=SETTLEMENT_LIST_SCHEMA, model='uie-x-base',
|
||||
task_path='model/settlement_list_model', layout_analysis=LAYOUT_ANALYSIS, precision='fp16')
|
||||
DISCHARGE_IE = Taskflow('information_extraction', schema=DISCHARGE_RECORD_SCHEMA, model='uie-x-base',
|
||||
task_path='model/discharge_record_model', layout_analysis=LAYOUT_ANALYSIS, precision='fp16')
|
||||
COST_IE = Taskflow('information_extraction', schema=COST_LIST_SCHEMA, model='uie-x-base', device_id=1,
|
||||
task_path='model/cost_list_model', layout_analysis=LAYOUT_ANALYSIS, precision='fp16')
|
||||
"""
|
||||
SETTLEMENT_IE = Taskflow("information_extraction", schema=SETTLEMENT_LIST_SCHEMA, model="uie-x-base",
|
||||
task_path="model/settlement_list_model", layout_analysis=LAYOUT_ANALYSIS, precision="fp16")
|
||||
DISCHARGE_IE = Taskflow("information_extraction", schema=DISCHARGE_RECORD_SCHEMA, model="uie-x-base",
|
||||
task_path="model/discharge_record_model", layout_analysis=LAYOUT_ANALYSIS, precision="fp16")
|
||||
COST_IE = Taskflow("information_extraction", schema=COST_LIST_SCHEMA, model="uie-x-base", device_id=1,
|
||||
task_path="model/cost_list_model", layout_analysis=LAYOUT_ANALYSIS, precision="fp16")
|
||||
|
||||
OCR = PaddleOCR(use_angle_cls=False, show_log=False, gpu_id=1, det_db_box_thresh=0.3)
|
||||
OCR = PaddleOCR(
|
||||
device="gpu:0",
|
||||
ocr_version="PP-OCRv4",
|
||||
use_textline_orientation=False,
|
||||
# 检测像素阈值,输出的概率图中,得分大于该阈值的像素点才会被认为是文字像素点
|
||||
text_det_thresh=0.1,
|
||||
# 检测框阈值,检测结果边框内,所有像素点的平均得分大于该阈值时,该结果会被认为是文字区域
|
||||
text_det_box_thresh=0.3,
|
||||
)
|
||||
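Elsewhere in this diff the shared `OCR` instance is always consumed through `util.get_ocr_layout`, which now returns the layout together with a rotation angle; a hedged sketch of that call, with the image path purely hypothetical.

```python
from photo_review import OCR   # the instance configured above
from util import util

layouts, angle = util.get_ocr_layout(OCR, "sample_settlement.jpg")  # hypothetical path
for lay in layouts:
    print(lay[1])              # recognized text of each detected box
print("rotation angle:", angle)
```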
@@ -22,11 +22,11 @@ from photo_review import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_E
|
||||
PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR, \
|
||||
ADMISSION_ID, SETTLEMENT_ID, AGE, OCR, SETTLEMENT_IE, DISCHARGE_IE, COST_IE, PHHD_BATCH_SIZE, SLEEP_MINUTES, \
|
||||
UPPERCASE_MEDICAL_EXPENSES, HOSPITAL_ALIAS, HOSPITAL_FILTER, DEPARTMENT_ALIAS, DEPARTMENT_FILTER
|
||||
from ucloud import ufile
|
||||
from ucloud import ufile, BUCKET
|
||||
from util import image_util, util, html_util
|
||||
from util.data_util import handle_date, handle_decimal, parse_department, handle_name, \
|
||||
handle_insurance_type, handle_original_data, handle_hospital, handle_department, handle_id, handle_age, parse_money, \
|
||||
parse_hospital
|
||||
handle_insurance_type, handle_original_data, handle_hospital, handle_department, handle_age, parse_money, \
|
||||
parse_hospital, handle_doctor, handle_text, handle_admission_id, handle_settlement_id
|
||||
|
||||
|
||||
# 合并信息抽取结果
|
||||
@@ -36,18 +36,25 @@ def merge_result(result1, result2):
|
||||
return result1
|
||||
|
||||
|
||||
def ie_temp_image(ie, ocr, image):
|
||||
def ie_temp_image(ie, ocr, image, is_screenshot=False):
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
|
||||
cv2.imwrite(temp_file.name, image)
|
||||
|
||||
ie_result = []
|
||||
ocr_pure_text = ''
|
||||
angle = '0'
|
||||
try:
|
||||
layout = util.get_ocr_layout(ocr, temp_file.name)
|
||||
layout, angle = util.get_ocr_layout(ocr, temp_file.name, is_screenshot)
|
||||
if not layout:
|
||||
# 无识别结果
|
||||
ie_result = []
|
||||
else:
|
||||
ie_result = ie({"doc": temp_file.name, "layout": layout})[0]
|
||||
for lay in layout:
|
||||
ocr_pure_text += lay[1]
|
||||
except MemoryError as e:
|
||||
# 显存不足时应该抛出错误,让程序重启,同时释放显存
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.error("信息抽取时出错", exc_info=e)
|
||||
finally:
|
||||
@@ -55,7 +62,7 @@ def ie_temp_image(ie, ocr, image):
|
||||
os.remove(temp_file.name)
|
||||
except Exception as e:
|
||||
logging.info(f"删除临时文件 {temp_file.name} 时出错", exc_info=e)
|
||||
return ie_result
|
||||
return ie_result, ocr_pure_text, angle
|
||||
|
||||
|
||||
# 关键信息提取
|
||||
@@ -147,12 +154,16 @@ def get_better_image_from_qrcode(image, image_id, dpi=150):
|
||||
# 关键信息提取
|
||||
def information_extraction(ie, phrecs, identity):
|
||||
result = {}
|
||||
ocr_text = ''
|
||||
for phrec in phrecs:
|
||||
img_path = ufile.get_private_url(phrec.cfjaddress)
|
||||
if not img_path:
|
||||
continue
|
||||
|
||||
image = image_util.read(img_path)
|
||||
image, exif_data = image_util.read(img_path)
|
||||
if image is None:
|
||||
# 图片可能因为某些原因获取不到
|
||||
continue
|
||||
|
||||
# 尝试从二维码中获取高清图片
|
||||
better_image, text = get_better_image_from_qrcode(image, phrec.cfjaddress)
|
||||
@@ -165,7 +176,7 @@ def information_extraction(ie, phrecs, identity):
|
||||
if text:
|
||||
info_extract = ie(text)[0]
|
||||
else:
|
||||
info_extract = ie_temp_image(ie, OCR, image)
|
||||
info_extract = ie_temp_image(ie, OCR, image, True)[0]
|
||||
ie_result = {'result': info_extract, 'angle': '0'}
|
||||
|
||||
now = util.get_default_datetime()
|
||||
@@ -183,25 +194,20 @@ def information_extraction(ie, phrecs, identity):
|
||||
|
||||
result = merge_result(result, ie_result['result'])
|
||||
else:
|
||||
is_screenshot = image_util.is_screenshot(image, exif_data)
|
||||
target_images = []
|
||||
# target_images += detector.request_book_areas(image) # 识别文档区域并裁剪
|
||||
if not target_images:
|
||||
target_images.append(image) # 识别失败
|
||||
angle_count = defaultdict(int, {'0': 0}) # 分割后图片的最优角度统计
|
||||
for target_image in target_images:
|
||||
# dewarped_image = dewarp.dewarp_image(target_image) # 去扭曲
|
||||
dewarped_image = target_image
|
||||
angles = image_util.parse_rotation_angles(dewarped_image)
|
||||
|
||||
split_results = image_util.split(dewarped_image)
|
||||
split_results = image_util.split(target_image)
|
||||
for split_result in split_results:
|
||||
if split_result['img'] is None or split_result['img'].size == 0:
|
||||
continue
|
||||
rotated_img = image_util.rotate(split_result['img'], int(angles[0]))
|
||||
ie_results = [{'result': ie_temp_image(ie, OCR, rotated_img), 'angle': angles[0]}]
|
||||
if not ie_results[0]['result'] or len(ie_results[0]['result']) < len(ie.kwargs.get('schema')):
|
||||
rotated_img = image_util.rotate(split_result['img'], int(angles[1]))
|
||||
ie_results.append({'result': ie_temp_image(ie, OCR, rotated_img), 'angle': angles[1]})
|
||||
ie_temp_result = ie_temp_image(ie, OCR, split_result['img'], is_screenshot)
|
||||
ocr_text += ie_temp_result[1]
|
||||
ie_results = [{'result': ie_temp_result[0], 'angle': ie_temp_result[2]}]
|
||||
now = util.get_default_datetime()
|
||||
best_angle = ['0', 0]
|
||||
for ie_result in ie_results:
|
||||
@@ -231,13 +237,15 @@ def information_extraction(ie, phrecs, identity):
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
|
||||
cv2.imwrite(temp_file.name, image)
|
||||
try:
|
||||
ufile.upload_file(phrec.cfjaddress, temp_file.name)
|
||||
if img_angle != '0':
|
||||
ufile.upload_file(phrec.cfjaddress, temp_file.name)
|
||||
logging.info(f'旋转图片[{phrec.cfjaddress}]替换成功,已旋转{img_angle}度。')
|
||||
# 修正旋转角度
|
||||
for zx_ie_result in zx_ie_results:
|
||||
zx_ie_result.rotation_angle -= int(img_angle)
|
||||
else:
|
||||
ufile.copy_file(BUCKET, phrec.cfjaddress, "drg2015", phrec.cfjaddress)
|
||||
ufile.upload_file(phrec.cfjaddress, temp_file.name)
|
||||
logging.info(f'高清图片[{phrec.cfjaddress}]替换成功!')
|
||||
except Exception as e:
|
||||
logging.error(f'上传图片({phrec.cfjaddress})失败', exc_info=e)
|
||||
@@ -247,8 +255,21 @@ def information_extraction(ie, phrecs, identity):
|
||||
session = MysqlSession()
|
||||
session.add_all(zx_ie_results)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
# # 添加清晰度测试
|
||||
# if better_image is None:
|
||||
# # 替换后图片默认清晰
|
||||
# clarity_result = image_util.parse_clarity(image)
|
||||
# unsharp_flag = 0 if (clarity_result[0] == 0 and clarity_result[1] >= 0.8) else 1
|
||||
# update_clarity = (update(ZxPhrec).where(ZxPhrec.pk_phrec == phrec.pk_phrec).values(
|
||||
# cfjaddress2=json.dumps(clarity_result),
|
||||
# unsharp_flag=unsharp_flag,
|
||||
# ))
|
||||
# session.execute(update_clarity)
|
||||
# session.commit()
|
||||
# session.close()
|
||||
|
||||
result['ocr_text'] = ocr_text
|
||||
return result
|
||||
|
||||
|
||||
@@ -293,7 +314,7 @@ def save_or_update_ie(table, pk_phhd, data):
|
||||
if db_data:
|
||||
# 更新
|
||||
db_data.update_time = now
|
||||
db_data.creator = HOSTNAME
|
||||
db_data.updater = HOSTNAME
|
||||
for k, v in data.items():
|
||||
setattr(db_data, k, v)
|
||||
else:
|
||||
@@ -379,8 +400,8 @@ def settlement_task(pk_phhd, settlement_list, identity):
|
||||
get_best_value_in_keys(settlement_list_ie_result, PERSONAL_FUNDED_AMOUNT)),
|
||||
"medical_insurance_type_str": handle_original_data(
|
||||
get_best_value_in_keys(settlement_list_ie_result, MEDICAL_INSURANCE_TYPE)),
|
||||
"admission_id": handle_id(get_best_value_in_keys(settlement_list_ie_result, ADMISSION_ID)),
|
||||
"settlement_id": handle_id(get_best_value_in_keys(settlement_list_ie_result, SETTLEMENT_ID)),
|
||||
"admission_id": handle_admission_id(get_best_value_in_keys(settlement_list_ie_result, ADMISSION_ID)),
|
||||
"settlement_id": handle_settlement_id(get_best_value_in_keys(settlement_list_ie_result, SETTLEMENT_ID)),
|
||||
}
|
||||
settlement_data["admission_date"] = handle_date(settlement_data["admission_date_str"])
|
||||
settlement_data["admission_date"] = handle_date(settlement_data["admission_date_str"])
|
||||
@@ -394,6 +415,10 @@ def settlement_task(pk_phhd, settlement_list, identity):
|
||||
get_best_value_in_keys(settlement_list_ie_result, MEDICAL_EXPENSES))
|
||||
settlement_data["medical_expenses_str"] = handle_original_data(parse_money_result[0])
|
||||
settlement_data["medical_expenses"] = parse_money_result[1]
|
||||
|
||||
if not settlement_data["settlement_id"]:
|
||||
# 如果没有结算单号就填住院号
|
||||
settlement_data["settlement_id"] = settlement_data["admission_id"]
|
||||
save_or_update_ie(ZxIeSettlement, pk_phhd, settlement_data)
|
||||
|
||||
|
||||
@@ -408,9 +433,10 @@ def discharge_task(pk_phhd, discharge_record, identity):
|
||||
"name": handle_name(get_best_value_in_keys(discharge_record_ie_result, PATIENT_NAME)),
|
||||
"admission_date_str": handle_original_data(get_best_value_in_keys(discharge_record_ie_result, ADMISSION_DATE)),
|
||||
"discharge_date_str": handle_original_data(get_best_value_in_keys(discharge_record_ie_result, DISCHARGE_DATE)),
|
||||
"doctor": handle_name(get_best_value_in_keys(discharge_record_ie_result, DOCTOR)),
|
||||
"admission_id": handle_id(get_best_value_in_keys(discharge_record_ie_result, ADMISSION_ID)),
|
||||
"doctor": handle_doctor(get_best_value_in_keys(discharge_record_ie_result, DOCTOR)),
|
||||
"admission_id": handle_admission_id(get_best_value_in_keys(discharge_record_ie_result, ADMISSION_ID)),
|
||||
"age": handle_age(get_best_value_in_keys(discharge_record_ie_result, AGE)),
|
||||
"content": handle_text(discharge_record_ie_result['ocr_text']),
|
||||
}
|
||||
discharge_data["admission_date"] = handle_date(discharge_data["admission_date_str"])
|
||||
discharge_data["discharge_date"] = handle_date(discharge_data["discharge_date_str"])
|
||||
@@ -476,7 +502,8 @@ def cost_task(pk_phhd, cost_list, identity):
|
||||
"name": handle_name(get_best_value_in_keys(cost_list_ie_result, PATIENT_NAME)),
|
||||
"admission_date_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, ADMISSION_DATE)),
|
||||
"discharge_date_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, DISCHARGE_DATE)),
|
||||
"medical_expenses_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, MEDICAL_EXPENSES))
|
||||
"medical_expenses_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, MEDICAL_EXPENSES)),
|
||||
"content": handle_text(cost_list_ie_result['ocr_text']),
|
||||
}
|
||||
cost_data["admission_date"] = handle_date(cost_data["admission_date_str"])
|
||||
cost_data["discharge_date"] = handle_date(cost_data["discharge_date_str"])
|
||||
|
||||
@@ -1,16 +1,11 @@
numpy==1.26.4
onnxconverter-common==1.14.0
aistudio_sdk==0.2.6
onnxconverter-common==1.15.0
onnxruntime-gpu==1.22.0
OpenCC==1.1.6
opencv-python==4.6.0.66
paddle2onnx==1.2.3
paddleclas==2.5.2
paddlenlp==2.6.1
paddleocr==2.7.3
pillow==10.4.0
paddlenlp==3.0.0b4
paddleocr==3.1.1
PyMuPDF==1.26.3
pymysql==1.1.1
requests==2.32.3
sqlacodegen==2.3.0.post1
sqlalchemy==1.4.52
tenacity==8.5.0
ufile==3.2.9
zxing-cpp==2.2.0
ufile==3.2.11
zxing-cpp==2.3.0
25
tool/batch_download.py
Normal file
@@ -0,0 +1,25 @@
# Batch-download images
import pandas as pd
import requests

from ucloud import ufile

if __name__ == '__main__':
    # Read the xlsx file
    file_path = r'D:\Echo\Downloads\Untitled.xlsx'  # replace with the path to your file
    df = pd.read_excel(file_path)

    # Iterate over the cfjaddress column
    for address in df['cfjaddress']:
        img_url = ufile.get_private_url(address)
        # Download the image and save it locally
        response = requests.get(img_url)

        # Check whether the request succeeded
        if response.status_code == 200:
            # Path where the image is saved
            with open(f'../img/{address}', 'wb') as file:
                file.write(response.content)
            print(f"{address}下载成功")
        else:
            print(f"{address}下载失败,状态码: {response.status_code}")
43
tool/check_clarity_model.py
Normal file
@@ -0,0 +1,43 @@
import pandas as pd

from db import MysqlSession
from db.mysql import ZxPhrec


def check_unclarity():
    file_path = r'D:\Echo\Downloads\unclare.xlsx'  # replace with the path to your file
    df = pd.read_excel(file_path)

    # Iterate over the pk_phrec column
    output_file_path = 'check_unclarity_result.txt'
    with open(output_file_path, 'w') as f:
        for pk_phrec in df['pk_phrec']:
            session = MysqlSession()
            phrec = session.query(ZxPhrec.pk_phrec, ZxPhrec.cfjaddress2, ZxPhrec.unsharp_flag).filter(
                ZxPhrec.pk_phrec == pk_phrec
            ).one_or_none()
            session.close()
            if phrec and phrec.unsharp_flag == 0:
                f.write(f"{phrec.pk_phrec} {phrec.cfjaddress2}\n")


def check_clarity():
    file_path = r'D:\Echo\Downloads\unclare.xlsx'
    df = pd.read_excel(file_path)
    session = MysqlSession()
    phrecs = (session.query(ZxPhrec.pk_phrec, ZxPhrec.cfjaddress2, ZxPhrec.unsharp_flag)
              .filter(ZxPhrec.pk_phrec >= 30810206)
              .filter(ZxPhrec.pk_phrec <= 31782247)
              .filter(ZxPhrec.unsharp_flag == 1)
              .all())
    session.close()

    output_file_path = 'check_clarity_result.txt'
    with open(output_file_path, 'w') as f:
        for phrec in phrecs:
            if phrec.pk_phrec not in df['pk_phrec'].to_list():
                f.write(f"{phrec.pk_phrec} {phrec.cfjaddress2}\n")


if __name__ == '__main__':
    check_clarity()
17
tool/paddle2nb.py
Normal file
@@ -0,0 +1,17 @@
# Convert a Paddle model to nb format; paddlelite 2.10 is recommended
from paddlelite.lite import Opt

if __name__ == '__main__':
    # 1. Create the Opt instance
    opt = Opt()
    # 2. Set the input model directory
    opt.set_model_dir("./model")
    # 3. Set the target place: arm, x86, opencl, npu
    # arm generally means CPU, opencl means GPU
    opt.set_valid_places("arm")
    # 4. Set the converted model type: naive_buffer or protobuf
    opt.set_model_type("naive_buffer")
    # 5. Set the output model path
    opt.set_optimize_out("model")
    # 6. Run the model optimization
    opt.run()
21
update_dev.sh
Normal file
@@ -0,0 +1,21 @@
#!/bin/bash
# Project update script
echo "开始更新测试项目..."
# Back up the docker-compose config
cp -i docker-compose.dev.yml docker-compose-backup.dev.yml
# Pull the latest code
git pull
# Build the new image
docker-compose -f docker-compose.dev.yml build
# Stop the old containers
docker-compose -f docker-compose-backup.dev.yml down
# Start the new containers
docker-compose -f docker-compose.dev.yml up -d
# Remove the docker-compose backup
rm -f docker-compose-backup.dev.yml
# Show running containers
docker ps
# Show images
docker images
# Done
echo "测试项目更新完成,请确认容器版本正确,自行删除过期镜像。"
@@ -2,7 +2,7 @@ import logging
import re
from datetime import datetime

from util import util
from util import util, string_util


# Handle money-type data
@@ -12,6 +12,7 @@ def handle_decimal(string):
    string = re.sub(r'[^0-9.]', '', string)
    if not string:
        return ""
    string = string_util.full_to_half(string)
    if "." not in string:
        if len(string) > 2:
            result = string[:-2] + "." + string[-2:]
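
Based on the branch shown above, an amount string with no decimal point gets one inserted before its last two digits after non-numeric characters are stripped. A hedged illustration (assuming this branch ultimately returns result and the empty-string guard shown is the only other early exit):

# Illustrative only.
assert handle_decimal("总计123456元") == "1234.56"   # non-numeric characters stripped, decimal point inserted
assert handle_decimal("") == ""                      # empty input returns an empty string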

@@ -89,13 +90,17 @@ def handle_date(string):
def handle_hospital(string):
    if not string:
        return ""
    return string[:255]
    # Allow only Chinese characters and digits
    string = re.sub(r'[^⺀-鿿0-9]', '', string)
    return string[:200]


def handle_department(string):
    if not string:
        return ""
    return string[:255]
    # Allow only Chinese characters
    string = re.sub(r'[^⺀-鿿]', '', string)
    return string[:200]


def parse_department(string):
@@ -119,7 +124,13 @@ def parse_department(string):
def handle_name(string):
    if not string:
        return ""
    return re.sub(r'[^⺀-鿿·]', '', string)[:30]
    return re.sub(r'[^⺀-鿿·]', '', string)[:20]


def handle_doctor(string):
    if not string:
        return "无"
    return re.sub(r'[^⺀-鿿·]', '', string)[:20]


# Handle medical-insurance-type data
@@ -150,18 +161,29 @@ def handle_original_data(string):


# Handle ID-type data
def handle_id(string):
def handle_admission_id(string):
    if not string:
        return ""
    # Allow only letters and digits
    string = re.sub(r'[^0-9a-zA-Z]', '', string)
    # Truncate so overlong values do not fail to insert into the database
    return string[:50]
    return string[:20]


def handle_settlement_id(string):
    if not string:
        return ""
    # Allow only letters and digits
    string = re.sub(r'[^0-9a-zA-Z]', '', string)
    # Truncate so overlong values do not fail to insert into the database
    return string[:30]


# Handle age-type data
def handle_age(string):
    if not string:
        return ""
    string = string.split("岁")[0]
    string = string_util.full_to_half(string.split("岁")[0])
    num = re.sub(r'\D', '', string)
    return num[-3:]
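
A small, hedged illustration of the new handlers' cleanup and truncation behaviour, assuming the unchanged parts of the functions are as the lines above suggest:

# Illustrative only.
assert handle_admission_id("住院号:ZY20240012345") == "ZY20240012345"   # non-alphanumerics stripped, capped at 20 chars
assert handle_settlement_id("JS-2024-000123") == "JS2024000123"          # same cleanup, capped at 30 chars
assert handle_age("６５岁") == "65"                                       # full-width digits normalized, "岁" suffix dropped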

@@ -178,3 +200,9 @@ def parse_hospital(string):
    split_hospitals = string_without_company.replace("医院", "医院 ")
    result += split_hospitals.strip().split(" ")
    return result


def handle_text(string):
    if not string:
        return ""
    return string[:16383]


@@ -1,9 +1,12 @@
import logging
import math
import urllib.request
from io import BytesIO

import cv2
import numpy
from PIL import Image
from PIL.ExifTags import TAGS
from paddleclas import PaddleClas
from tenacity import retry, stop_after_attempt, wait_random

@@ -14,20 +17,36 @@ def read(image_path):
    """
    Read an image from the network or from local disk
    :param image_path: URL or local path
    :return: the image as a NumPy array
    :return: the image as a NumPy array, plus its EXIF data
    """
    if image_path.startswith("http"):
        # Send an HTTP request and fetch the image data
        resp = urllib.request.urlopen(image_path, timeout=60)
        # Read the response as a byte stream
        image_data = resp.read()
        # Convert the byte stream to a NumPy array
        image_np = numpy.frombuffer(image_data, numpy.uint8)
        # Decode the NumPy array into an OpenCV image
        image = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
    else:
        image = cv2.imread(image_path)
    return image
        with open(image_path, "rb") as f:
            image_data = f.read()

    # Parse the EXIF info (from the raw byte stream)
    exif_data = {}
    try:
        # Open the raw byte stream with PIL
        with Image.open(BytesIO(image_data)) as img:
            # Get the EXIF dict
            exif_info = img._getexif()
            if exif_info:
                # Map numeric EXIF tag IDs to readable names (e.g. 36867 -> "DateTimeOriginal")
                for tag_id, value in exif_info.items():
                    tag_name = TAGS.get(tag_id, tag_id)
                    exif_data[tag_name] = value
    except Exception as e:
        logging.error("解析EXIF信息失败", exc_info=e)
    # Convert the byte stream to a NumPy array
    image_np = numpy.frombuffer(image_data, numpy.uint8)
    # Decode the NumPy array into an OpenCV image
    image = cv2.imdecode(image_np, cv2.IMREAD_COLOR)
    return image, exif_data


def capture(image, rectangle):
@@ -61,7 +80,7 @@ def split(image, ratio=1.414, overlap=0.05, x_compensation=3):
    """
    split_result = []
    if isinstance(image, str):
        image = read(image)
        image, _ = read(image)
    height, width = image.shape[:2]
    hw_ratio = height / width
    wh_ratio = width / height
@@ -72,14 +91,18 @@ def split(image, ratio=1.414, overlap=0.05, x_compensation=3):
        for i in range(math.ceil(height / step)):
            offset = round(step * i)
            cropped_img = capture(image, [0, offset, width, offset + new_img_height])
            split_result.append({"img": cropped_img, "x_offset": 0, "y_offset": offset})
            if cropped_img.shape[0] > 0:
                # Rounding error can produce a zero-height crop; skip it in that case
                split_result.append({"img": cropped_img, "x_offset": 0, "y_offset": offset})
    elif wh_ratio > ratio:  # too wide
        new_img_width = height * ratio
        step = height * (ratio - overlap * x_compensation)  # text is usually horizontal, so use a larger overlap for horizontal crops
        for i in range(math.ceil(width / step)):
            offset = round(step * i)
            cropped_img = capture(image, [offset, 0, offset + new_img_width, width])
            split_result.append({"img": cropped_img, "x_offset": offset, "y_offset": 0})
            if cropped_img.shape[1] > 0:
                # Rounding error can produce a zero-width crop; skip it in that case
                split_result.append({"img": cropped_img, "x_offset": offset, "y_offset": 0})
    else:
        split_result.append({"img": image, "x_offset": 0, "y_offset": 0})
    return split_result
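
A small sketch exercising the new zero-size guard in split(): for a tall synthetic image every returned crop should now have a non-zero height (the image size below is an arbitrary test value):

# Sketch only: a tall blank test image; split() should no longer return zero-height crops.
tall_image = numpy.zeros((4000, 1000, 3), dtype=numpy.uint8)
crops = split(tall_image)
assert all(crop["img"].shape[0] > 0 for crop in crops)
print([crop["y_offset"] for crop in crops])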

@@ -247,3 +270,78 @@ def combined(img1, img2):
    combined_img[:height1, :width1] = img1
    combined_img[:height2, width1:width1 + width2] = img2
    return combined_img


def parse_clarity(image):
    """
    Assess image sharpness
    :param image: the image as a NumPy array or a file path
    :return: the predicted class and its confidence
    """
    clarity_result = [1, 0]
    model = PaddleClas(inference_model_dir=r"model/clas/clarity_assessment", use_gpu=True)
    clas_result = model.predict(input_data=image)
    try:
        clas_result = next(clas_result)[0]
        clarity_result = [clas_result["class_ids"][0], clas_result["scores"][0]]
    except Exception as e:
        logging.error("获取图片清晰度失败", exc_info=e)
    return clarity_result
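
A hedged usage sketch tying parse_clarity back to the threshold used in the commented-out clarity check earlier in this change (class 0 with a score of at least 0.8 is treated as sharp; the sample path is a placeholder):

# Sketch only: classify one image and apply the same rule as the commented-out block above.
image, _ = read("sample.jpg")                      # placeholder path
class_id, score = parse_clarity(image)
is_sharp = (class_id == 0 and score >= 0.8)
logging.info(f"clarity class={class_id}, score={score}, sharp={is_sharp}")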


def is_photo_by_exif(exif_tags):
    """Use EXIF data to judge whether the image is a photo"""
    # EXIF tags that photos usually carry
    photo_tags = [
        'FNumber',  # aperture
        'ExposureTime',  # exposure time
        'ISOSpeedRatings',  # ISO
        'FocalLength',  # focal length
        'LensModel',  # lens model
        'GPSLatitude'  # GPS location
    ]

    # Count how many photo-related EXIF tags are present
    photo_tag_count = 0
    if exif_tags:
        for tag in photo_tags:
            if tag in exif_tags:
                photo_tag_count += 1
    # Two or more photo-related EXIF tags suggest a photo
    if photo_tag_count >= 2:
        return True
    # Not clearly a photo: return False
    return False


def is_screenshot_by_image_features(image):
    """Use image features to judge whether the image is a screenshot"""
    # Standard-deviation threshold for edge pixels; below it the image is treated as a screenshot
    edge_std_threshold = 20.0
    try:
        # Check the uniformity of the edge pixels (screenshot edges are usually tidier)
        edge_pixels = []
        # Take 10 pixels along each image edge
        edge_pixels.extend(image[:10, :].flatten())  # top edge
        edge_pixels.extend(image[-10:, :].flatten())  # bottom edge
        edge_pixels.extend(image[:, :10].flatten())  # left edge
        edge_pixels.extend(image[:, -10:].flatten())  # right edge

        # Standard deviation of the edge pixels (smaller means more uniform)
        edge_std = numpy.std(edge_pixels)
        logging.info(f"边缘像素标准差: {edge_std}")
        return edge_std < edge_std_threshold
    except Exception as e:
        logging.error("图像特征分析失败", exc_info=e)
        return False


def is_screenshot(image, exif_tags):
    """Combined check for whether the image is a screenshot"""
    # Check the EXIF data first
    result_of_exif = is_photo_by_exif(exif_tags)
    # Clear photo EXIF info means it is a photo, not a screenshot
    if result_of_exif:
        return False
    # Otherwise analyse the image features
    return is_screenshot_by_image_features(image)
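
A short sketch of how the two helpers above combine with read(), in the same way information_extraction uses them to decide whether document-orientation and unwarping steps can be skipped (the URL is a placeholder):

# Sketch only.
image, exif_data = read("https://example.com/receipt.jpg")   # placeholder URL
if image is not None:
    screenshot = is_screenshot(image, exif_data)
    # In photo_review this flag is passed on so that screenshots skip
    # document orientation classification and unwarping during OCR.
    print("screenshot" if screenshot else "photo / scan")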

@@ -1,3 +1,6 @@
import unicodedata


def blank(string):
    """
    Check whether a string is empty or whitespace only
@@ -5,3 +8,23 @@ def blank(string):
    :return: whether the string is empty or whitespace only
    """
    return not string or string.isspace()


def full_to_half(string):
    """
    Convert full-width characters to half-width
    :param string: the input string
    :return: the half-width string
    """
    if not isinstance(string, str):
        raise TypeError("全角转半角的输入必须是字符串类型")

    if not string:
        return string

    half_string = ''.join([
        unicodedata.normalize('NFKC', char) if unicodedata.east_asian_width(char) in ['F', 'W'] else char
        for char in string
    ])

    return half_string
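
A quick illustration of full_to_half on the kind of full-width digits and punctuation that show up in OCR output:

# NFKC normalization maps full-width forms to their half-width equivalents.
assert full_to_half("１２３４.５６") == "1234.56"
assert full_to_half("ＡＢＣ:１０岁") == "ABC:10岁"   # CJK characters such as "岁" are left unchanged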

19
util/util.py
@@ -12,9 +12,10 @@ def get_default_datetime():
    return datetime.now().strftime('%Y-%m-%d %H:%M:%S')


def get_ocr_layout(ocr, img_path):
def get_ocr_layout(ocr, img_path, is_screenshot=False):
    """
    Run OCR and convert the result into a suitable layout format
    :param is_screenshot: whether the image is a screenshot
    :param ocr: the OCR model
    :param img_path: local image path
    :return:
@@ -36,18 +37,18 @@ def get_ocr_layout(ocr, img_path):
        return True

    layout = []
    ocr_result = ocr.ocr(img_path, cls=False)
    ocr_result = ocr_result[0]
    ocr_result = ocr.predict(input=img_path, use_doc_orientation_classify=not is_screenshot, use_doc_unwarping=not is_screenshot)
    ocr_result = next(ocr_result)
    if not ocr_result:
        return layout
    for segment in ocr_result:
        box = segment[0]
        return layout, "0"
    angle = ocr_result.get("doc_preprocessor_res", {}).get("angle", "0")
    for i in range(len(ocr_result.get('rec_texts'))):
        box = ocr_result.get("rec_polys")[i].tolist()
        box = _get_box(box)
        if not _normal_box(box):
            continue
        text = segment[1][0]
        layout.append((box, text))
    return layout
        layout.append((box, ocr_result.get("rec_texts")[i]))
    return layout, str(angle)
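
A brief sketch of how the new (layout, angle) return value is consumed, mirroring ie_temp_image in photo_review; the image path is a placeholder, and ie stands for one of the Taskflow instances (e.g. SETTLEMENT_IE) while OCR is the PaddleOCR instance configured above:

# Sketch only.
layout, angle = get_ocr_layout(OCR, "page.jpg", is_screenshot=False)   # placeholder path
if layout:
    ie_result = ie({"doc": "page.jpg", "layout": layout})[0]           # same call shape as ie_temp_image
    ocr_pure_text = "".join(text for _, text in layout)
    print(angle, len(ocr_pure_text))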


def delete_temp_file(temp_files):

@@ -24,7 +24,7 @@ def write_visual_result(image, angle=0, layout=None, result=None):
    img_name = img[:last_dot_index]
    img_type = img[last_dot_index + 1:]

    img_array = image_util.read(image)
    img_array, _ = image_util.read(image)
    if angle != 0:
        img_array = image_util.rotate(img_array, angle)
    with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
@@ -63,7 +63,7 @@ def visual_model_test(model_type, test_img, task_path, schema):
            img["y_offset"] -= offset_y

            temp_files_paths.append(temp_file.name)
            parsed_doc = util.get_ocr_layout(
            parsed_doc, _ = util.get_ocr_layout(
                PaddleOCR(det_db_box_thresh=0.3, det_db_thresh=0.1, det_limit_side_len=1248, drop_score=0.3,
                          save_crop_res=False),
                temp_file.name)