Compare commits
124 Commits
| SHA1 | Author | Date | |
|---|---|---|---|
| 3044f1fc08 | |||
| da18d890f7 | |||
| ae52d0594e | |||
| d51e56b8f2 | |||
| 83339b5e58 | |||
| 727743d20e | |||
| 814730a0f0 | |||
| c9894d257e | |||
| 68043e5773 | |||
| fe58bb3bfa | |||
| ce44a81603 | |||
| 4b90bf6dfa | |||
| 248e49bf4b | |||
| 401954dca0 | |||
| 6529dc3d98 | |||
| 3f93bd476a | |||
| d85b3fff8f | |||
| 61a7802674 | |||
| 9556da47e9 | |||
| 3710450221 | |||
| 27a4395ca0 | |||
| f116798c30 | |||
| 8c47beb00c | |||
| 74920869e7 | |||
| 9d0db073d6 | |||
| 8e7745f1f6 | |||
| cc53243647 | |||
| 39da0d8a00 | |||
| a2e1f10261 | |||
| e1bd9f3786 | |||
| 46f295d422 | |||
| 1a0caf30d0 | |||
| 25df420be8 | |||
| b5dffaf5bd | |||
| 0e4cfd10b6 | |||
| f98969d957 | |||
| 0c9bed8661 | |||
| d0b4a77817 | |||
| 00e5ca7c30 | |||
| 5dee4ed568 | |||
| 06869e691f | |||
| 8e06fdafa0 | |||
| 84d106c7de | |||
| 9c41fab95c | |||
| 0060c4ad59 | |||
| d374e0743a | |||
| 947b4f20f3 | |||
| 445d57e8c6 | |||
| b09f16fe23 | |||
| c28fc62d3f | |||
| b332aa00dd | |||
| 5af6256376 | |||
| 15ea3ff96f | |||
| 19237d3a3c | |||
| 0b0882d456 | |||
| 304f6897f0 | |||
| a9f172fdb0 | |||
| ac4e4ff8f8 | |||
| f7fbe709bf | |||
| 396550058f | |||
| b9ac638b38 | |||
| 894cab4f0b | |||
| bb6d9c3b47 | |||
| f8280e87ee | |||
| 608a647621 | |||
| 7b9d9ca589 | |||
| d9b24e906d | |||
| 97c7b2cfce | |||
| 004dd12004 | |||
| cc9d020008 | |||
| 7335553080 | |||
| ebb10b2816 | |||
| 98fb9fa861 | |||
| c75415164e | |||
| 03d8652b8f | |||
| e3be5cf4b2 | |||
| c92b549480 | |||
| d36740d729 | |||
| a1dea6f29c | |||
| 0fc0c80d6f | |||
| f3930cc7bd | |||
| a11cefb999 | |||
| 5c0fc0f819 | |||
| 77010f0598 | |||
| e4b58e30c0 | |||
| 15fe5d4f0d | |||
| fc69aa5b9d | |||
| 795134f566 | |||
| a3fa1e502e | |||
| 7a4cb5263a | |||
| 46be9a26be | |||
| f1149854ce | |||
| 117b29a737 | |||
| 3219f28934 | |||
| 2e1c0a57c7 | |||
| 2dcd2d2a34 | |||
| 153eb70f84 | |||
| b5aba0418b | |||
| 603b027ca6 | |||
| d4c54b04f5 | |||
| fc3e7b4ed4 | |||
| a62c2af816 | |||
| 0618754da2 | |||
| c5a03ad16f | |||
| ff9d612e67 | |||
| 86d28096d4 | |||
| 87180cd282 | |||
| f13ffd1fe9 | |||
| 09f62b36a9 | |||
| 186cab0317 | |||
| 101b2126f4 | |||
| d4a695e9ea | |||
| 72794f699e | |||
| 3189caf7aa | |||
| b8c1202957 | |||
| 7647df7d74 | |||
| 3438cf6e0e | |||
| 90a6d5ec75 | |||
| 9c21152823 | |||
| c091a82a91 | |||
| a2a82df21c | |||
| f0c03e763b | |||
| 7b6e78373c | |||
| 65b7126348 |
@@ -238,8 +238,11 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# 模型通过卷绑定挂载到容器中
/model
# 通过卷绑定挂载到容器中
/log
/services/paddle_services/log
/services/paddle_services/model
/tmp_img
# docker
Dockerfile
docker-compose*.yml

.gitignore (vendored, 8 changes)
@@ -142,7 +142,11 @@ cython_debug/
.idea

### Model
model
services/paddle_services/model

### Log Backups
*.log.*-*-*
*.log.*-*-*

### Tmp Files
/tmp_img
/test_img

Dockerfile (12 changes)
@@ -1,5 +1,5 @@
# 使用官方的paddle镜像作为基础
FROM registry.baidubce.com/paddlepaddle/paddle:2.6.1-gpu-cuda12.0-cudnn8.9-trt8.6
# 使用官方的python镜像作为基础
FROM python:3.10.15-bookworm

# 设置工作目录
WORKDIR /app

@@ -13,12 +13,10 @@ ENV PYTHONUNBUFFERED=1 \

# 安装依赖
COPY requirements.txt /app/requirements.txt
COPY packages /app/packages
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo '$TZ' > /etc/timezone \
    && python3 -m pip install --upgrade pip \
    && pip install --no-cache-dir -r requirements.txt \
    && pip uninstall -y onnxruntime onnxruntime-gpu \
    && pip install onnxruntime-gpu==1.18.0 --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
    && sed -i 's|deb.debian.org|mirrors.aliyun.com|g' /etc/apt/sources.list.d/debian.sources \
    && apt-get update && apt-get install libgl1 -y \
    && pip install --no-cache-dir -r requirements.txt

# 将当前目录内容复制到容器的/app内
COPY . /app
@@ -6,7 +6,7 @@

1. 从Git远程仓库克隆项目到本地。

2. 将深度学习模型复制到./model目录下,具体请看[模型更新](#模型更新)部分。
2. 将深度学习模型复制到./services/paddle_services/model目录下,具体请看[模型更新](#模型更新)部分。

3. 安装docker和docker-compose。
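For reference, a minimal sketch of step 2 (placing the model files) in Python; the source directory `downloaded_model/` is a placeholder, and only the target path `services/paddle_services/model` comes from the README change above:

```python
import shutil
from pathlib import Path

# Placeholder: wherever the downloaded model weights currently live.
src = Path("downloaded_model")
# Target directory required by the new layout, per the README change above.
dst = Path("services/paddle_services/model")

dst.mkdir(parents=True, exist_ok=True)
# Copy the whole model tree into place (dirs_exist_ok needs Python 3.8+).
shutil.copytree(src, dst, dirs_exist_ok=True)
```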
@@ -126,8 +126,5 @@ bash update.sh
2. 新增扭曲矫正功能
21. 版本号:1.14.0
1. 新增二维码识别替换高清图片功能
22. 版本号:1.15.0
1. 新增图片清晰度测试
23. 版本号:1.16.0
1. 优化结算单号规则
2. 新增判断截图方法
22. 版本号:2.0.0
1. 项目架构调整,模型全部采用接口调用

api_test.py (new file, 4 changes)
@@ -0,0 +1,4 @@
import time

if __name__ == '__main__':
    time.sleep(3600)
@@ -1,15 +0,0 @@
# 自动生成数据库表和sqlalchemy对应的Model
import subprocess

from db import DB_URL

if __name__ == '__main__':
    table = input("请输入表名:")
    out_file = f"db/{table}.py"
    command = f"sqlacodegen {DB_URL} --outfile={out_file} --tables={table}"

    try:
        subprocess.run(command, shell=True, check=True)
        print(f"{table}.py文件生成成功!请检查并复制到合适的文件中!")
    except Exception as e:
        print(f"生成{table}.py文件时发生错误: {e}")
@@ -1 +0,0 @@
@@ -1,102 +0,0 @@
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from io import BytesIO
|
||||
from itertools import groupby
|
||||
|
||||
import requests
|
||||
from PIL import ImageDraw, Image, ImageFont
|
||||
|
||||
from db import MysqlSession
|
||||
from db.mysql import ZxIeCost, ZxIeDischarge, ZxIeSettlement, ZxPhhd, ZxIeResult, ZxPhrec
|
||||
from ucloud import ufile
|
||||
from util import image_util
|
||||
|
||||
|
||||
def check_ie_result(pk_phhd):
|
||||
os.makedirs(f"./check_result/{pk_phhd}", exist_ok=True)
|
||||
json_result = {"pk_phhd": pk_phhd}
|
||||
session = MysqlSession()
|
||||
phhd = session.query(ZxPhhd.cXm).filter(ZxPhhd.pk_phhd == pk_phhd).one()
|
||||
json_result["cXm"] = phhd.cXm
|
||||
settlement = (session.query(ZxIeSettlement.pk_ie_settlement, ZxIeSettlement.name, ZxIeSettlement.admission_date,
|
||||
ZxIeSettlement.discharge_date, ZxIeSettlement.medical_expenses,
|
||||
ZxIeSettlement.personal_cash_payment, ZxIeSettlement.personal_account_payment,
|
||||
ZxIeSettlement.personal_funded_amount, ZxIeSettlement.medical_insurance_type,
|
||||
ZxIeSettlement.admission_id, ZxIeSettlement.settlement_id)
|
||||
.filter(ZxIeSettlement.pk_phhd == pk_phhd).one())
|
||||
settlement_result = settlement._asdict()
|
||||
json_result["settlement"] = settlement_result
|
||||
|
||||
discharge = (session.query(ZxIeDischarge.pk_ie_discharge, ZxIeDischarge.hospital, ZxIeDischarge.pk_yljg,
|
||||
ZxIeDischarge.department, ZxIeDischarge.pk_ylks, ZxIeDischarge.name, ZxIeDischarge.age,
|
||||
ZxIeDischarge.admission_date, ZxIeDischarge.discharge_date, ZxIeDischarge.doctor,
|
||||
ZxIeDischarge.admission_id)
|
||||
.filter(ZxIeDischarge.pk_phhd == pk_phhd).one())
|
||||
discharge_result = discharge._asdict()
|
||||
json_result["discharge"] = discharge_result
|
||||
|
||||
cost = session.query(ZxIeCost.pk_ie_cost, ZxIeCost.name, ZxIeCost.admission_date, ZxIeCost.discharge_date,
|
||||
ZxIeCost.medical_expenses).filter(ZxIeCost.pk_phhd == pk_phhd).one()
|
||||
cost_result = cost._asdict()
|
||||
json_result["cost"] = cost_result
|
||||
|
||||
phrecs = session.query(ZxPhrec.pk_phrec, ZxPhrec.pk_phhd, ZxPhrec.cRectype, ZxPhrec.cfjaddress).filter(
|
||||
ZxPhrec.pk_phhd == pk_phhd).all()
|
||||
for phrec in phrecs:
|
||||
img_name = phrec.cfjaddress
|
||||
img_path = ufile.get_private_url(img_name)
|
||||
|
||||
response = requests.get(img_path)
|
||||
image = Image.open(BytesIO(response.content)).convert("RGB")
|
||||
font_size = image.width * image.height / 200000
|
||||
font = ImageFont.truetype("./font/simfang.ttf", size=font_size)
|
||||
|
||||
ocr = session.query(ZxIeResult.id, ZxIeResult.content, ZxIeResult.rotation_angle, ZxIeResult.x_offset,
|
||||
ZxIeResult.y_offset).filter(ZxIeResult.pk_phrec == phrec.pk_phrec).all()
|
||||
if not ocr:
|
||||
os.makedirs(f"./check_result/{pk_phhd}/0", exist_ok=True)
|
||||
image.save(f"./check_result/{pk_phhd}/0/{img_name}")
|
||||
|
||||
for _, group_results in groupby(ocr, key=lambda x: x.id):
|
||||
draw = ImageDraw.Draw(image)
|
||||
for ocr_item in group_results:
|
||||
result = json.loads(ocr_item.content)
|
||||
rotation_angle = ocr_item.rotation_angle
|
||||
x_offset = ocr_item.x_offset
|
||||
y_offset = ocr_item.y_offset
|
||||
for key in result:
|
||||
for value in result[key]:
|
||||
box = value["bbox"][0]
|
||||
|
||||
if rotation_angle:
|
||||
box = image_util.invert_rotate_rectangle(box, (image.width / 2, image.height / 2),
|
||||
rotation_angle)
|
||||
if x_offset:
|
||||
box[0] += x_offset
|
||||
box[2] += x_offset
|
||||
if y_offset:
|
||||
box[1] += y_offset
|
||||
box[3] += y_offset
|
||||
|
||||
draw.rectangle(box, outline="red", width=2) # 绘制矩形
|
||||
draw.text((box[0], box[1] - font_size), key, fill="blue", font=font) # 在矩形上方绘制文本
|
||||
draw.text((box[0], box[3]), value["text"], fill="blue", font=font) # 在矩形下方绘制文本
|
||||
os.makedirs(f"./check_result/{pk_phhd}/{ocr_item.id}", exist_ok=True)
|
||||
image.save(f"./check_result/{pk_phhd}/{ocr_item.id}/{img_name}")
|
||||
session.close()
|
||||
|
||||
# 自定义JSON处理器
|
||||
def default(obj):
|
||||
if isinstance(obj, Decimal):
|
||||
return float(obj)
|
||||
if isinstance(obj, datetime.date):
|
||||
return obj.strftime("%Y-%m-%d")
|
||||
|
||||
with open(f"./check_result/{pk_phhd}/result.json", "w", encoding="utf-8") as json_file:
|
||||
json.dump(json_result, json_file, indent=4, ensure_ascii=False, default=default)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
check_ie_result(0)
|
||||
@@ -19,5 +19,5 @@ DB_URL = f'mysql+pymysql://{USERNAME}:{PASSWORD}@{HOSTNAME}:{PORT}/{DATABASE}'
SHOW_SQL = False

Engine = create_engine(DB_URL, echo=SHOW_SQL)
Base = declarative_base()
Base = declarative_base(Engine)
MysqlSession = sessionmaker(bind=Engine)
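As context for the hunk above, a minimal usage sketch of the resulting session factory; the query itself is illustrative (`ZxPhhd` and its columns are declared in `db/mysql.py`), and note that passing `Engine` positionally to `declarative_base` binds the metadata to the engine, a SQLAlchemy 1.x-style pattern:

```python
from db import MysqlSession          # session factory created in the hunk above
from db.mysql import ZxPhhd          # model declared elsewhere in this repo

# Illustrative query only: fetch a few pending cases and print their keys.
session = MysqlSession()
try:
    pending = (session.query(ZxPhhd.pk_phhd)
               .filter(ZxPhhd.exsuccess_flag == "1")
               .limit(10)
               .all())
    print([row.pk_phhd for row in pending])
finally:
    session.close()
```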
db/mysql.py (35 changes)
@@ -1,5 +1,5 @@
|
||||
# coding: utf-8
|
||||
from sqlalchemy import Column, DECIMAL, Date, DateTime, Index, String, text, LargeBinary, Text
|
||||
from sqlalchemy import Column, DECIMAL, Date, DateTime, Index, String, text, LargeBinary
|
||||
from sqlalchemy.dialects.mysql import BIT, CHAR, INTEGER, TINYINT, VARCHAR
|
||||
|
||||
from db import Base
|
||||
@@ -56,8 +56,7 @@ class ZxIeCost(Base):
|
||||
|
||||
pk_ie_cost = Column(INTEGER(11), primary_key=True, comment='费用明细信息抽取主键')
|
||||
pk_phhd = Column(INTEGER(11), nullable=False, unique=True, comment='报销案子主键')
|
||||
content = Column(Text, comment='详细内容')
|
||||
name = Column(String(20), comment='患者姓名')
|
||||
name = Column(String(30), comment='患者姓名')
|
||||
admission_date_str = Column(String(255), comment='入院日期字符串')
|
||||
admission_date = Column(Date, comment='入院日期')
|
||||
discharge_date_str = Column(String(255), comment='出院日期字符串')
|
||||
@@ -97,19 +96,19 @@ class ZxIeDischarge(Base):
|
||||
|
||||
pk_ie_discharge = Column(INTEGER(11), primary_key=True, comment='出院记录信息抽取主键')
|
||||
pk_phhd = Column(INTEGER(11), nullable=False, unique=True, comment='报销案子主键')
|
||||
content = Column(Text, comment='详细内容')
|
||||
hospital = Column(String(200), comment='医院')
|
||||
content = Column(String(5000), comment='详细内容')
|
||||
hospital = Column(String(255), comment='医院')
|
||||
pk_yljg = Column(INTEGER(11), comment='医院主键')
|
||||
department = Column(String(200), comment='科室')
|
||||
department = Column(String(255), comment='科室')
|
||||
pk_ylks = Column(INTEGER(11), comment='科室主键')
|
||||
name = Column(String(20), comment='患者姓名')
|
||||
name = Column(String(30), comment='患者姓名')
|
||||
age = Column(INTEGER(3), comment='年龄')
|
||||
admission_date_str = Column(String(255), comment='入院日期字符串')
|
||||
admission_date = Column(Date, comment='入院日期')
|
||||
discharge_date_str = Column(String(255), comment='出院日期字符串')
|
||||
discharge_date = Column(Date, comment='出院日期')
|
||||
doctor = Column(String(20), comment='主治医生')
|
||||
admission_id = Column(String(20), comment='住院号')
|
||||
doctor = Column(String(30), comment='主治医生')
|
||||
admission_id = Column(String(50), comment='住院号')
|
||||
create_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
|
||||
creator = Column(String(255), comment='创建人')
|
||||
update_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
|
||||
@@ -123,7 +122,7 @@ class ZxIeResult(Base):
|
||||
pk_ocr = Column(INTEGER(11), primary_key=True, comment='图片OCR识别主键')
|
||||
pk_phhd = Column(INTEGER(11), nullable=False, comment='报销单主键')
|
||||
pk_phrec = Column(INTEGER(11), nullable=False, comment='图片主键')
|
||||
id = Column(INTEGER(11), nullable=False, comment='识别批次')
|
||||
id = Column(CHAR(32), nullable=False, comment='识别批次')
|
||||
cfjaddress = Column(String(200), nullable=False, comment='云存储文件名')
|
||||
content = Column(String(5000), comment='OCR识别内容')
|
||||
rotation_angle = Column(INTEGER(11), comment='旋转角度')
|
||||
@@ -141,7 +140,7 @@ class ZxIeSettlement(Base):
|
||||
|
||||
pk_ie_settlement = Column(INTEGER(11), primary_key=True, comment='结算清单信息抽取主键')
|
||||
pk_phhd = Column(INTEGER(11), nullable=False, unique=True, comment='报销案子主键')
|
||||
name = Column(String(20), comment='患者姓名')
|
||||
name = Column(String(30), comment='患者姓名')
|
||||
admission_date_str = Column(String(255), comment='入院日期字符串')
|
||||
admission_date = Column(Date, comment='入院日期')
|
||||
discharge_date_str = Column(String(255), comment='出院日期字符串')
|
||||
@@ -155,9 +154,9 @@ class ZxIeSettlement(Base):
|
||||
personal_funded_amount_str = Column(String(255), comment='自费金额字符串')
|
||||
personal_funded_amount = Column(DECIMAL(18, 2), comment='自费金额')
|
||||
medical_insurance_type_str = Column(String(255), comment='医保类型字符串')
|
||||
medical_insurance_type = Column(String(10), comment='医保类型')
|
||||
admission_id = Column(String(20), comment='住院号')
|
||||
settlement_id = Column(String(30), comment='医保结算单号码')
|
||||
medical_insurance_type = Column(String(40), comment='医保类型')
|
||||
admission_id = Column(String(50), comment='住院号')
|
||||
settlement_id = Column(String(50), comment='医保结算单号码')
|
||||
create_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
|
||||
creator = Column(String(255), comment='创建人')
|
||||
update_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
|
||||
@@ -415,17 +414,19 @@ class ZxIeReview(Base):
|
||||
pk_ie_review = Column(INTEGER(11), primary_key=True, comment='自动审核主键')
|
||||
pk_phhd = Column(INTEGER(11), nullable=False, comment='报销案子主键')
|
||||
success = Column(BIT(1))
|
||||
integrity = Column(BIT(1))
|
||||
has_settlement = Column(BIT(1))
|
||||
has_discharge = Column(BIT(1))
|
||||
has_cost = Column(BIT(1))
|
||||
full_page = Column(BIT(1))
|
||||
page_description = Column(String(255), comment='具体缺页描述')
|
||||
consistency = Column(BIT(1), comment='三项资料一致性。0:不一致;1:一致')
|
||||
name_match = Column(CHAR(1), server_default=text("'0'"),
|
||||
comment='三项资料姓名是否一致。0:不一致;1:一致;2:结算单不一致;3:出院记录不一致;4:费用清单不一致;5:与报销申请对象不一致')
|
||||
comment='三项资料姓名是否一致。0:都不一致;1:一致;2:结算单不一致;3:出院记录不一致;4:费用清单不一致;5:与报销申请对象不一致')
|
||||
admission_date_match = Column(CHAR(1), server_default=text("'0'"),
|
||||
comment='三项资料入院日期是否一致。0:不一致;1:一致;2:结算单不一致;3:出院记录不一致;4:费用清单不一致')
|
||||
comment='三项资料入院日期是否一致。0:都不一致;1:一致;2:结算单不一致;3:出院记录不一致;4:费用清单不一致')
|
||||
discharge_date_match = Column(CHAR(1), server_default=text("'0'"),
|
||||
comment='三项资料出院日期是否一致。0:不一致;1:一致;2:结算单不一致;3:出院记录不一致;4:费用清单不一致')
|
||||
comment='三项资料出院日期是否一致。0:都不一致;1:一致;2:结算单不一致;3:出院记录不一致;4:费用清单不一致')
|
||||
create_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP"), comment='创建时间')
|
||||
creator = Column(String(255), comment='创建人')
|
||||
update_time = Column(DateTime, server_default=text("CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP"),
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
# 删除本地数据库中的过期数据
|
||||
import logging.config
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from db import MysqlSession
|
||||
from db.mysql import ZxPhhd, ZxIeCost, ZxIeDischarge, ZxIeResult, ZxIeSettlement, ZxPhrec
|
||||
from log import LOGGING_CONFIG
|
||||
|
||||
# 过期时间(不建议小于1个月)
|
||||
EXPIRATION_DAYS = 183
|
||||
# 批量删除数量(最好在1000~10000之间)
|
||||
BATCH_SIZE = 5000
|
||||
# 数据库会话对象
|
||||
session = None
|
||||
|
||||
|
||||
def batch_delete_by_pk_phhd(model, pk_phhds):
|
||||
"""
|
||||
批量删除指定模型中主键在指定列表中的数据
|
||||
|
||||
参数:
|
||||
model:SQLAlchemy模型类(对应数据库表)
|
||||
pk_phhds:待删除的主键值列表
|
||||
|
||||
返回:
|
||||
删除的记录数量
|
||||
"""
|
||||
delete_count = (
|
||||
session.query(model)
|
||||
.filter(model.pk_phhd.in_(pk_phhds))
|
||||
.delete(synchronize_session=False)
|
||||
)
|
||||
session.commit()
|
||||
logging.getLogger("sql").info(f"{model.__tablename__}成功删除{delete_count}条数据")
|
||||
return delete_count
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.config.dictConfig(LOGGING_CONFIG)
|
||||
deadline = datetime.now() - timedelta(days=EXPIRATION_DAYS)
|
||||
double_deadline = deadline - timedelta(days=EXPIRATION_DAYS)
|
||||
session = MysqlSession()
|
||||
try:
|
||||
while 1:
|
||||
# 已经走完所有流程的案子,超过过期时间后删除
|
||||
phhds = (session.query(ZxPhhd.pk_phhd)
|
||||
.filter(ZxPhhd.paint_flag == "9")
|
||||
.filter(ZxPhhd.billdate < deadline)
|
||||
.limit(BATCH_SIZE)
|
||||
.all())
|
||||
if not phhds or len(phhds) <= 0:
|
||||
# 没有通过审核,可能会重拍补拍上传的案子,超过两倍过期时间后删除
|
||||
phhds = (session.query(ZxPhhd.pk_phhd)
|
||||
.filter(ZxPhhd.exsuccess_flag == "9")
|
||||
.filter(ZxPhhd.paint_flag == "0")
|
||||
.filter(ZxPhhd.billdate < double_deadline)
|
||||
.limit(BATCH_SIZE)
|
||||
.all())
|
||||
if not phhds or len(phhds) <= 0:
|
||||
# 没有符合条件的数据,退出循环
|
||||
break
|
||||
pk_phhd_values = [phhd.pk_phhd for phhd in phhds]
|
||||
logging.getLogger("sql").info(f"过期的pk_phhd有{','.join(map(str, pk_phhd_values))}")
|
||||
batch_delete_by_pk_phhd(ZxPhrec, pk_phhd_values)
|
||||
batch_delete_by_pk_phhd(ZxIeResult, pk_phhd_values)
|
||||
batch_delete_by_pk_phhd(ZxIeSettlement, pk_phhd_values)
|
||||
batch_delete_by_pk_phhd(ZxIeDischarge, pk_phhd_values)
|
||||
batch_delete_by_pk_phhd(ZxIeCost, pk_phhd_values)
|
||||
batch_delete_by_pk_phhd(ZxPhhd, pk_phhd_values)
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logging.getLogger('error').error('过期数据删除失败!', exc_info=e)
|
||||
finally:
|
||||
session.close()
|
||||
det_api.py (32 changes)
@@ -1,32 +0,0 @@
|
||||
import base64
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from flask import Flask, request, jsonify
|
||||
|
||||
from paddle_detection import detector
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
@app.route("/det/detect_books", methods=['POST'])
|
||||
def detect_books():
|
||||
try:
|
||||
file = request.files['image']
|
||||
image_data = file.read()
|
||||
nparr = np.frombuffer(image_data, np.uint8)
|
||||
image = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
|
||||
result = detector.get_book_areas(image)
|
||||
encoded_images = []
|
||||
for i in result:
|
||||
_, encoded_image = cv2.imencode('.jpg', i)
|
||||
byte_stream = encoded_image.tobytes()
|
||||
img_str = base64.b64encode(byte_stream).decode('utf-8')
|
||||
encoded_images.append(img_str)
|
||||
return jsonify(encoded_images), 200
|
||||
except Exception as e:
|
||||
return jsonify({'error': str(e)}), 500
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.run("0.0.0.0")
|
||||
@@ -1,46 +1,150 @@
|
||||
x-env:
|
||||
&template
|
||||
image: fcb_photo_review:1.16.0
|
||||
x-base:
|
||||
&base_template
|
||||
restart: always
|
||||
|
||||
x-review:
|
||||
&review_template
|
||||
<<: *template
|
||||
x-project:
|
||||
&project_template
|
||||
<<: *base_template
|
||||
image: fcb_photo_review:2.0.0
|
||||
volumes:
|
||||
- ./log:/app/log
|
||||
- ./model:/app/model
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- device_ids: [ '0', '1' ]
|
||||
capabilities: [ 'gpu' ]
|
||||
driver: 'nvidia'
|
||||
- ./tmp_img:/app/tmp_img
|
||||
|
||||
x-mask:
|
||||
&mask_template
|
||||
<<: *template
|
||||
x-paddle:
|
||||
&paddle_template
|
||||
<<: *base_template
|
||||
image: fcb_paddle:0.0.1
|
||||
volumes:
|
||||
- ./log:/app/log
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- device_ids: [ '1' ]
|
||||
capabilities: [ 'gpu' ]
|
||||
driver: 'nvidia'
|
||||
- ./services/paddle_services/log:/app/log
|
||||
- ./services/paddle_services/model:/app/model
|
||||
- ./tmp_img:/app/tmp_img
|
||||
|
||||
services:
|
||||
ocr:
|
||||
<<: *paddle_template
|
||||
build:
|
||||
context: ./services/paddle_services
|
||||
container_name: ocr
|
||||
hostname: ocr
|
||||
command: [ '-w', '4', 'ocr:app', '--bind', '0.0.0.0:5001' ]
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- device_ids: [ '0' ]
|
||||
capabilities: [ 'gpu' ]
|
||||
driver: 'nvidia'
|
||||
|
||||
ie_settlement:
|
||||
<<: *paddle_template
|
||||
container_name: ie_settlement
|
||||
hostname: ie_settlement
|
||||
command: [ '-w', '5', 'ie_settlement:app', '--bind', '0.0.0.0:5002' ]
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- device_ids: [ '0' ]
|
||||
capabilities: [ 'gpu' ]
|
||||
driver: 'nvidia'
|
||||
|
||||
ie_discharge:
|
||||
<<: *paddle_template
|
||||
container_name: ie_discharge
|
||||
hostname: ie_discharge
|
||||
command: [ '-w', '5', 'ie_discharge:app', '--bind', '0.0.0.0:5003' ]
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- device_ids: [ '1' ]
|
||||
capabilities: [ 'gpu' ]
|
||||
driver: 'nvidia'
|
||||
|
||||
ie_cost:
|
||||
<<: *paddle_template
|
||||
container_name: ie_cost
|
||||
hostname: ie_cost
|
||||
command: [ '-w', '5', 'ie_cost:app', '--bind', '0.0.0.0:5004' ]
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- device_ids: [ '1' ]
|
||||
capabilities: [ 'gpu' ]
|
||||
driver: 'nvidia'
|
||||
|
||||
clas_orientation:
|
||||
<<: *paddle_template
|
||||
container_name: clas_orientation
|
||||
hostname: clas_orientation
|
||||
command: [ '-w', '3', 'clas_orientation:app', '--bind', '0.0.0.0:5005' ]
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- device_ids: [ '0' ]
|
||||
capabilities: [ 'gpu' ]
|
||||
driver: 'nvidia'
|
||||
|
||||
det_book:
|
||||
<<: *paddle_template
|
||||
container_name: det_book
|
||||
hostname: det_book
|
||||
command: [ '-w', '4', 'det_book:app', '--bind', '0.0.0.0:5006' ]
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- device_ids: [ '1' ]
|
||||
capabilities: [ 'gpu' ]
|
||||
driver: 'nvidia'
|
||||
|
||||
dewarp:
|
||||
<<: *paddle_template
|
||||
container_name: dewarp
|
||||
hostname: dewarp
|
||||
command: [ '-w', '4', 'dewarp:app', '--bind', '0.0.0.0:5007' ]
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- device_ids: [ '0' ]
|
||||
capabilities: [ 'gpu' ]
|
||||
driver: 'nvidia'
|
||||
|
||||
# clas_text:
|
||||
# <<: *paddle_template
|
||||
# container_name: clas_text
|
||||
# hostname: clas_text
|
||||
# command: [ '-w', '3', 'clas_text:app', '--bind', '0.0.0.0:5008' ]
|
||||
# deploy:
|
||||
# resources:
|
||||
# reservations:
|
||||
# devices:
|
||||
# - device_ids: [ '1' ]
|
||||
# capabilities: [ 'gpu' ]
|
||||
# driver: 'nvidia'
|
||||
|
||||
photo_review_1:
|
||||
<<: *review_template
|
||||
<<: *project_template
|
||||
build:
|
||||
context: .
|
||||
container_name: photo_review_1
|
||||
hostname: photo_review_1
|
||||
depends_on:
|
||||
- ocr
|
||||
- ie_settlement
|
||||
- ie_discharge
|
||||
- ie_cost
|
||||
- clas_orientation
|
||||
- det_book
|
||||
- dewarp
|
||||
# - clas_text
|
||||
command: [ 'photo_review.py', '--clean', 'True' ]
|
||||
|
||||
photo_review_2:
|
||||
<<: *review_template
|
||||
<<: *project_template
|
||||
container_name: photo_review_2
|
||||
hostname: photo_review_2
|
||||
depends_on:
|
||||
@@ -48,57 +152,41 @@ services:
|
||||
command: [ 'photo_review.py' ]
|
||||
|
||||
photo_review_3:
|
||||
<<: *review_template
|
||||
<<: *project_template
|
||||
container_name: photo_review_3
|
||||
hostname: photo_review_3
|
||||
depends_on:
|
||||
- photo_review_2
|
||||
- photo_review_1
|
||||
command: [ 'photo_review.py' ]
|
||||
|
||||
photo_review_4:
|
||||
<<: *review_template
|
||||
<<: *project_template
|
||||
container_name: photo_review_4
|
||||
hostname: photo_review_4
|
||||
depends_on:
|
||||
- photo_review_3
|
||||
- photo_review_1
|
||||
command: [ 'photo_review.py' ]
|
||||
|
||||
photo_review_5:
|
||||
<<: *review_template
|
||||
<<: *project_template
|
||||
container_name: photo_review_5
|
||||
hostname: photo_review_5
|
||||
depends_on:
|
||||
- photo_review_4
|
||||
- photo_review_1
|
||||
command: [ 'photo_review.py' ]
|
||||
|
||||
photo_mask_1:
|
||||
<<: *mask_template
|
||||
<<: *project_template
|
||||
container_name: photo_mask_1
|
||||
hostname: photo_mask_1
|
||||
depends_on:
|
||||
- photo_review_5
|
||||
- photo_review_1
|
||||
command: [ 'photo_mask.py', '--clean', 'True' ]
|
||||
|
||||
photo_mask_2:
|
||||
<<: *mask_template
|
||||
<<: *project_template
|
||||
container_name: photo_mask_2
|
||||
hostname: photo_mask_2
|
||||
depends_on:
|
||||
- photo_mask_1
|
||||
command: [ 'photo_mask.py' ]
|
||||
#
|
||||
# photo_review_6:
|
||||
# <<: *review_template
|
||||
# container_name: photo_review_6
|
||||
# hostname: photo_review_6
|
||||
# depends_on:
|
||||
# - photo_mask_2
|
||||
# command: [ 'photo_review.py' ]
|
||||
#
|
||||
# photo_review_7:
|
||||
# <<: *review_template
|
||||
# container_name: photo_review_7
|
||||
# hostname: photo_review_7
|
||||
# depends_on:
|
||||
# - photo_review_6
|
||||
# command: [ 'photo_review.py' ]
|
||||
command: [ 'photo_mask.py' ]
|
||||
@@ -1,14 +1,15 @@
|
||||
import os
|
||||
import socket
|
||||
|
||||
# 项目根目录
|
||||
PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
# 获取主机名,方便区分容器
|
||||
HOSTNAME = socket.gethostname()
|
||||
# 检测日志文件的路径是否存在,不存在则创建
|
||||
LOG_PATHS = [
|
||||
f"log/{HOSTNAME}/ucloud",
|
||||
f"log/{HOSTNAME}/error",
|
||||
f"log/{HOSTNAME}/qr",
|
||||
f"log/{HOSTNAME}/sql",
|
||||
os.path.join(PROJECT_ROOT, 'log', HOSTNAME, 'ucloud'),
|
||||
os.path.join(PROJECT_ROOT, 'log', HOSTNAME, 'error'),
|
||||
os.path.join(PROJECT_ROOT, 'log', HOSTNAME, 'qr'),
|
||||
]
|
||||
for path in LOG_PATHS:
|
||||
if not os.path.exists(path):
|
||||
@@ -75,16 +76,6 @@ LOGGING_CONFIG = {
|
||||
'backupCount': 14,
|
||||
'encoding': 'utf-8',
|
||||
},
|
||||
'sql': {
|
||||
'class': 'logging.handlers.TimedRotatingFileHandler',
|
||||
'level': 'INFO',
|
||||
'formatter': 'standard',
|
||||
'filename': f'log/{HOSTNAME}/sql/fcb_photo_review_sql.log',
|
||||
'when': 'midnight',
|
||||
'interval': 1,
|
||||
'backupCount': 14,
|
||||
'encoding': 'utf-8',
|
||||
},
|
||||
},
|
||||
|
||||
# loggers定义了日志记录器
|
||||
@@ -109,10 +100,5 @@ LOGGING_CONFIG = {
|
||||
'level': 'DEBUG',
|
||||
'propagate': False,
|
||||
},
|
||||
'sql': {
|
||||
'handlers': ['console', 'sql'],
|
||||
'level': 'DEBUG',
|
||||
'propagate': False,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -8,15 +8,13 @@ MAX_WAIT_TIME = 3
|
||||
# 程序异常短信配置
|
||||
ERROR_EMAIL_CONFIG = {
|
||||
# SMTP服务器地址
|
||||
"smtp_server": "smtp.163.com",
|
||||
'smtp_server': 'smtp.163.com',
|
||||
# 连接SMTP的端口
|
||||
"port": 994,
|
||||
'port': 994,
|
||||
# 发件人邮箱地址,请确保开启了SMTP邮件服务!
|
||||
"sender": "EchoLiu618@163.com",
|
||||
'sender': 'EchoLiu618@163.com',
|
||||
# 授权码--用于登录第三方邮件客户端的专用密码,不是邮箱密码
|
||||
"authorization_code": "OKPQLIIVLVGRZYVH",
|
||||
'authorization_code': 'OKPQLIIVLVGRZYVH',
|
||||
# 收件人邮箱地址
|
||||
"receivers": ["1515783401@qq.com"],
|
||||
# 尝试次数
|
||||
"retry_times": 3,
|
||||
'receivers': ['1515783401@qq.com'],
|
||||
}
|
||||
@@ -5,18 +5,18 @@ from email.mime.text import MIMEText
|
||||
|
||||
from tenacity import retry, stop_after_attempt, wait_random
|
||||
|
||||
from auto_email import ERROR_EMAIL_CONFIG, TRY_TIMES, MIN_WAIT_TIME, MAX_WAIT_TIME
|
||||
from log import HOSTNAME
|
||||
from my_email import ERROR_EMAIL_CONFIG, TRY_TIMES, MIN_WAIT_TIME, MAX_WAIT_TIME
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(TRY_TIMES), wait=wait_random(MIN_WAIT_TIME, MAX_WAIT_TIME), reraise=True,
|
||||
after=lambda x: logging.warning("发送邮件失败!"))
|
||||
after=lambda x: logging.warning('发送邮件失败!'))
|
||||
def send_email(email_config, massage):
|
||||
smtp_server = email_config["smtp_server"]
|
||||
port = email_config["port"]
|
||||
sender = email_config["sender"]
|
||||
authorization_code = email_config["authorization_code"]
|
||||
receivers = email_config["receivers"]
|
||||
smtp_server = email_config['smtp_server']
|
||||
port = email_config['port']
|
||||
sender = email_config['sender']
|
||||
authorization_code = email_config['authorization_code']
|
||||
receivers = email_config['receivers']
|
||||
mail = smtplib.SMTP_SSL(smtp_server, port) # 连接SMTP服务
|
||||
mail.login(sender, authorization_code) # 登录到SMTP服务
|
||||
mail.sendmail(sender, receivers, massage.as_string()) # 发送邮件
|
||||
@@ -34,13 +34,13 @@ def send_error_email(program_name, error_name, error_detail):
|
||||
"""
|
||||
|
||||
# SMTP 服务器配置
|
||||
sender = ERROR_EMAIL_CONFIG["sender"]
|
||||
receivers = ERROR_EMAIL_CONFIG["receivers"]
|
||||
sender = ERROR_EMAIL_CONFIG['sender']
|
||||
receivers = ERROR_EMAIL_CONFIG['receivers']
|
||||
|
||||
# 获取程序出错的时间
|
||||
error_time = datetime.datetime.strftime(datetime.datetime.today(), "%Y-%m-%d %H:%M:%S:%f")
|
||||
error_time = datetime.datetime.strftime(datetime.datetime.today(), '%Y-%m-%d %H:%M:%S:%f')
|
||||
# 邮件内容
|
||||
subject = f"【程序异常提醒】{program_name}({HOSTNAME}) {error_time}" # 邮件的标题
|
||||
subject = f'【程序异常提醒】{program_name}({HOSTNAME}) {error_time}' # 邮件的标题
|
||||
content = f'''<div class="emailcontent" style="width:100%;max-width:720px;text-align:left;margin:0 auto;padding-top:80px;padding-bottom:20px">
|
||||
<div class="emailtitle">
|
||||
<h1 style="color:#fff;background:#51a0e3;line-height:70px;font-size:24px;font-weight:400;padding-left:40px;margin:0">程序运行异常通知</h1>
|
||||
@@ -1 +0,0 @@
README_cn.md

@@ -1,4 +0,0 @@
from onnxruntime import InferenceSession

PADDLE_DET = InferenceSession("model/object_det_model/ppyoloe_plus_crn_l_80e_coco_w_nms.onnx",
                              providers=["CPUExecutionProvider"], provider_options=[{"device_id": 0}])

@@ -1 +0,0 @@
README_cn.md

@@ -1 +0,0 @@
README_cn.md

@@ -1 +0,0 @@
README_cn.md

@@ -1 +0,0 @@
README_cn.md

@@ -1 +0,0 @@
README_cn.md

@@ -1 +0,0 @@
README_cn.md

@@ -1 +0,0 @@
README_cn.md

@@ -1 +0,0 @@
README_cn.md

@@ -1 +0,0 @@
README_cn.md

@@ -1 +0,0 @@
README_en.md
@@ -1,76 +0,0 @@
|
||||
import base64
|
||||
import logging
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
from tenacity import retry, stop_after_attempt, wait_random
|
||||
|
||||
from paddle_detection import PADDLE_DET
|
||||
from paddle_detection.deploy.third_engine.onnx.infer import PredictConfig
|
||||
from paddle_detection.deploy.third_engine.onnx.preprocess import Compose
|
||||
from util import image_util, util
|
||||
|
||||
|
||||
def predict_image(infer_config, predictor, img_path):
|
||||
# load preprocess transforms
|
||||
transforms = Compose(infer_config.preprocess_infos)
|
||||
# predict image
|
||||
inputs = transforms(img_path)
|
||||
inputs["image"] = np.array(inputs["image"]).astype('float32')
|
||||
inputs_name = [var.name for var in predictor.get_inputs()]
|
||||
inputs = {k: inputs[k][None,] for k in inputs_name}
|
||||
|
||||
outputs = predictor.run(output_names=None, input_feed=inputs)
|
||||
|
||||
bboxes = np.array(outputs[0])
|
||||
result = defaultdict(list)
|
||||
for bbox in bboxes:
|
||||
if bbox[0] > -1 and bbox[1] > infer_config.draw_threshold:
|
||||
result[bbox[0]].append({"score": bbox[1], "box": bbox[2:]})
|
||||
return result
|
||||
|
||||
|
||||
def detect_image(img_path):
|
||||
infer_cfg = "model/object_det_model/infer_cfg.yml"
|
||||
# load infer config
|
||||
infer_config = PredictConfig(infer_cfg)
|
||||
|
||||
return predict_image(infer_config, PADDLE_DET, img_path)
|
||||
|
||||
|
||||
def get_book_areas(image):
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
|
||||
cv2.imwrite(temp_file.name, image)
|
||||
detect_result = detect_image(temp_file.name)
|
||||
util.delete_temp_file(temp_file.name)
|
||||
book_areas = detect_result[73]
|
||||
result = []
|
||||
for book_area in book_areas:
|
||||
result.append(image_util.capture(image, book_area["box"]))
|
||||
return result
|
||||
|
||||
|
||||
@retry(stop=stop_after_attempt(3), wait=wait_random(1, 3), reraise=True,
|
||||
after=lambda x: logging.warning("获取文档区域失败!"))
|
||||
def request_book_areas(image):
|
||||
url = "http://det_api:5000/det/detect_books"
|
||||
_, encoded_image = cv2.imencode('.jpg', image)
|
||||
byte_stream = encoded_image.tobytes()
|
||||
files = {"image": ("image.jpg", byte_stream)}
|
||||
response = requests.post(url, files=files)
|
||||
if response.status_code == 200:
|
||||
img_str_list = response.json()
|
||||
result = []
|
||||
for img_str in img_str_list:
|
||||
img_data = base64.b64decode(img_str)
|
||||
np_array = np.frombuffer(img_data, np.uint8)
|
||||
img = cv2.imdecode(np_array, cv2.IMREAD_COLOR)
|
||||
height, width = img.shape[:2]
|
||||
if max(height, width) / min(height, width) <= 6.5:
|
||||
result.append(img) # 过滤异常结果
|
||||
return result
|
||||
else:
|
||||
return []
|
||||
@@ -5,35 +5,36 @@ from time import sleep
|
||||
|
||||
from sqlalchemy import update
|
||||
|
||||
from auto_email.error_email import send_error_email
|
||||
from db import MysqlSession
|
||||
from db.mysql import ZxPhhd
|
||||
from log import LOGGING_CONFIG
|
||||
from my_email.error_email import send_error_email
|
||||
from photo_mask import auto_photo_mask, SEND_ERROR_EMAIL
|
||||
|
||||
if __name__ == '__main__':
|
||||
program_name = "照片审核自动涂抹脚本"
|
||||
program_name = '照片审核自动涂抹脚本'
|
||||
logging.config.dictConfig(LOGGING_CONFIG)
|
||||
|
||||
logging.info('等待接口服务启动...')
|
||||
sleep(60)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--clean", default=False, type=bool, help="是否将涂抹中的案子改为待涂抹状态")
|
||||
parser.add_argument('--clean', default=False, type=bool, help='是否将涂抹中的案子改为待涂抹状态')
|
||||
args = parser.parse_args()
|
||||
if args.clean:
|
||||
# 主要用于启动时,清除仍在涂抹中的案子
|
||||
session = MysqlSession()
|
||||
update_flag = (update(ZxPhhd).where(ZxPhhd.paint_flag == "2").values(paint_flag="1"))
|
||||
update_flag = (update(ZxPhhd).where(ZxPhhd.paint_flag == '2').values(paint_flag='1'))
|
||||
session.execute(update_flag)
|
||||
session.commit()
|
||||
session.close()
|
||||
logging.info("已释放残余的涂抹案子!")
|
||||
else:
|
||||
sleep(5)
|
||||
logging.info('已释放残余的涂抹案子!')
|
||||
|
||||
try:
|
||||
logging.info(f"【{program_name}】开始运行")
|
||||
logging.info(f'【{program_name}】开始运行')
|
||||
auto_photo_mask.main()
|
||||
except Exception as e:
|
||||
error_logger = logging.getLogger("error")
|
||||
error_logger = logging.getLogger('error')
|
||||
error_logger.error(traceback.format_exc())
|
||||
if SEND_ERROR_EMAIL:
|
||||
send_error_email(program_name, repr(e), traceback.format_exc())
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
"""
|
||||
项目配置
|
||||
"""
|
||||
@@ -40,13 +38,3 @@ SIMILAR_CHAR = {
|
||||
"侯": ["候"],
|
||||
"宇": ["字"],
|
||||
}
|
||||
|
||||
# 如果不希望识别出空格,可以设置use_space_char=False。做此项设置一定要测试,2.7.3版本此项设置有bug,会导致识别失败
|
||||
OCR = PaddleOCR(
|
||||
gpu_id=0,
|
||||
show_log=False,
|
||||
det_db_thresh=0.1,
|
||||
det_db_box_thresh=0.3,
|
||||
det_limit_side_len=1248,
|
||||
drop_score=0.3
|
||||
)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import logging.config
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import shutil
|
||||
import time
|
||||
import uuid
|
||||
from time import sleep
|
||||
|
||||
import cv2
|
||||
@@ -10,9 +12,10 @@ from sqlalchemy import update, and_
|
||||
from db import MysqlSession
|
||||
from db.mysql import ZxPhrec, ZxPhhd
|
||||
from log import HOSTNAME
|
||||
from photo_mask import OCR, PHHD_BATCH_SIZE, SLEEP_MINUTES, NAME_KEYS, ID_CARD_NUM_KEYS, SIMILAR_CHAR
|
||||
from photo_mask import PHHD_BATCH_SIZE, SLEEP_MINUTES, NAME_KEYS, ID_CARD_NUM_KEYS, SIMILAR_CHAR
|
||||
from photo_review import set_batch_id
|
||||
from ucloud import BUCKET, ufile
|
||||
from util import image_util, util
|
||||
from util import image_util, common_util, model_util
|
||||
|
||||
|
||||
def find_boxes(content, layout, offset=0, length=None, improve=False, image_path=None, extra_content=None):
|
||||
@@ -55,14 +58,15 @@ def find_boxes(content, layout, offset=0, length=None, improve=False, image_path
|
||||
if improve:
|
||||
# 再次识别,提高精度
|
||||
image = cv2.imread(image_path)
|
||||
img_name, img_ext = common_util.parse_save_path(image_path)
|
||||
# 截图时偏大一点
|
||||
capture_box = util.zoom_rectangle(box, 0.2)
|
||||
capture_box = common_util.zoom_rectangle(box, 0.2)
|
||||
captured_image = image_util.capture(image, capture_box)
|
||||
with tempfile.NamedTemporaryFile(suffix=".jpg", delete=False) as temp_file:
|
||||
captured_image, offset_x, offset_y = image_util.expand_to_a4_size(captured_image)
|
||||
cv2.imwrite(temp_file.name, captured_image)
|
||||
captured_image_path = common_util.get_processed_img_path(f'{img_name}.capture.{img_ext}')
|
||||
cv2.imwrite(captured_image_path, captured_image)
|
||||
captured_a4_img_path, offset_x, offset_y = image_util.expand_to_a4_size(captured_image_path)
|
||||
try:
|
||||
layouts = util.get_ocr_layout(OCR, temp_file.name)
|
||||
layouts = common_util.ocr_result_to_layout(model_util.ocr(captured_a4_img_path))
|
||||
except TypeError:
|
||||
# 如果是类型错误,大概率是没识别到文字
|
||||
layouts = []
|
||||
@@ -86,22 +90,17 @@ def find_boxes(content, layout, offset=0, length=None, improve=False, image_path
|
||||
temp_box[3] + capture_box[1] - offset_y,
|
||||
])
|
||||
break
|
||||
util.delete_temp_file(temp_file.name)
|
||||
|
||||
if not boxes:
|
||||
boxes.append(box)
|
||||
return boxes
|
||||
|
||||
|
||||
def get_mask_layout(image, name, id_card_num):
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
|
||||
cv2.imwrite(temp_file.name, image)
|
||||
|
||||
def get_mask_layout(img_path, name, id_card_num):
|
||||
result = []
|
||||
try:
|
||||
try:
|
||||
layouts = util.get_ocr_layout(OCR, temp_file.name)
|
||||
# layouts = OCR.parse({"doc": temp_file.name})["layout"]
|
||||
layouts = common_util.ocr_result_to_layout(model_util.ocr(img_path))
|
||||
except TypeError:
|
||||
# 如果是类型错误,大概率是没识别到文字
|
||||
layouts = []
|
||||
@@ -135,12 +134,12 @@ def get_mask_layout(image, name, id_card_num):
|
||||
find_id_card_num_by_key = True
|
||||
matches = re.findall(r, layout[1])
|
||||
for match in matches:
|
||||
result += find_boxes(match, layout, improve=True, image_path=temp_file.name, extra_content=r)
|
||||
result += find_boxes(match, layout, improve=True, image_path=img_path, extra_content=r)
|
||||
find_name_by_key = False
|
||||
break
|
||||
|
||||
if id_card_num in layout[1]:
|
||||
result += find_boxes(id_card_num, layout, improve=True, image_path=temp_file.name)
|
||||
result += find_boxes(id_card_num, layout, improve=True, image_path=img_path)
|
||||
find_id_card_num_by_key = False
|
||||
|
||||
def _find_boxes_by_keys(keys):
|
||||
@@ -160,13 +159,9 @@ def get_mask_layout(image, name, id_card_num):
|
||||
result += _find_boxes_by_keys(ID_CARD_NUM_KEYS)
|
||||
|
||||
return result
|
||||
except MemoryError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.error("涂抹时出错!", exc_info=e)
|
||||
return result
|
||||
finally:
|
||||
util.delete_temp_file(temp_file.name)
|
||||
|
||||
|
||||
def handle_image_for_mask(split_result):
|
||||
@@ -176,11 +171,15 @@ def handle_image_for_mask(split_result):
|
||||
return expand_img, split_result["x_offset"], split_result["y_offset"]
|
||||
|
||||
|
||||
def mask_photo(img_url, name, id_card_num, color=(255, 255, 255)):
|
||||
def _mask(i, n, icn, c):
|
||||
def mask_photo(img_path, name, id_card_num, color=(255, 255, 255)):
|
||||
def _mask(ip, n, icn, c):
|
||||
i = cv2.imread(ip)
|
||||
img_name, img_ext = common_util.parse_save_path(ip)
|
||||
do_mask = False
|
||||
split_results = image_util.split(i)
|
||||
split_results = image_util.split(ip)
|
||||
for split_result in split_results:
|
||||
if not split_result['img']:
|
||||
continue
|
||||
to_mask_img, x_offset, y_offset = handle_image_for_mask(split_result)
|
||||
results = get_mask_layout(to_mask_img, n, icn)
|
||||
|
||||
@@ -195,29 +194,27 @@ def mask_photo(img_url, name, id_card_num, color=(255, 255, 255)):
|
||||
result[3] + y_offset,
|
||||
)
|
||||
cv2.rectangle(i, (int(result[0]), int(result[1])), (int(result[2]), int(result[3])), c, -1, 0)
|
||||
return do_mask, i
|
||||
masked_path = common_util.get_processed_img_path(f'{img_name}.mask.{img_ext}')
|
||||
cv2.imwrite(masked_path, i)
|
||||
return do_mask, masked_path
|
||||
|
||||
# 打开图片
|
||||
image = image_util.read(img_url)
|
||||
if image is None:
|
||||
return False, image
|
||||
original_image = image
|
||||
is_masked, image = _mask(image, name, id_card_num, color)
|
||||
original_image = img_path
|
||||
is_masked, img_path = _mask(img_path, name, id_card_num, color)
|
||||
if not is_masked:
|
||||
# 如果没有涂抹,可能是图片方向不对
|
||||
angles = image_util.parse_rotation_angles(image)
|
||||
angles = model_util.clas_orientation(img_path)
|
||||
angle = angles[0]
|
||||
if angle != "0":
|
||||
image = image_util.rotate(image, int(angle))
|
||||
is_masked, image = _mask(image, name, id_card_num, color)
|
||||
img_path = image_util.rotate(img_path, int(angle))
|
||||
is_masked, img_path = _mask(img_path, name, id_card_num, color)
|
||||
if not is_masked:
|
||||
# 如果旋转后也没有涂抹,恢复原来的方向
|
||||
image = original_image
|
||||
img_path = original_image
|
||||
else:
|
||||
# 如果旋转有效果,打一个日志
|
||||
logging.info(f"图片旋转了{angle}°")
|
||||
|
||||
return is_masked, image
|
||||
return is_masked, img_path
|
||||
|
||||
|
||||
def photo_mask(pk_phhd, name, id_card_num):
|
||||
@@ -227,32 +224,37 @@ def photo_mask(pk_phhd, name, id_card_num):
|
||||
ZxPhrec.cRectype.in_(["3", "4"])
|
||||
)).all()
|
||||
session.close()
|
||||
# 同一批图的标识
|
||||
set_batch_id(uuid.uuid4().hex)
|
||||
processed_img_dir = common_util.get_processed_img_path('')
|
||||
os.makedirs(processed_img_dir, exist_ok=True)
|
||||
for phrec in phrecs:
|
||||
img_url = ufile.get_private_url(phrec.cfjaddress)
|
||||
if not img_url:
|
||||
continue
|
||||
|
||||
is_masked, image = mask_photo(img_url, name, id_card_num)
|
||||
original_img_path = common_util.save_to_local(img_url)
|
||||
img_path = common_util.get_processed_img_path(phrec.cfjaddress)
|
||||
shutil.copy2(original_img_path, img_path)
|
||||
is_masked, image = mask_photo(img_path, name, id_card_num)
|
||||
|
||||
# 如果涂抹了要备份以及更新
|
||||
if is_masked:
|
||||
ufile.copy_file(BUCKET, phrec.cfjaddress, "drg2015", phrec.cfjaddress)
|
||||
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
|
||||
cv2.imwrite(temp_file.name, image)
|
||||
try:
|
||||
ufile.upload_file(phrec.cfjaddress, temp_file.name)
|
||||
ufile.copy_file(BUCKET, phrec.cfjaddress, "drg2015", phrec.cfjaddress)
|
||||
ufile.upload_file(phrec.cfjaddress, image)
|
||||
session = MysqlSession()
|
||||
update_flag = (update(ZxPhrec).where(ZxPhrec.pk_phrec == phrec.pk_phrec).values(
|
||||
paint_user=HOSTNAME,
|
||||
paint_date=util.get_default_datetime()))
|
||||
paint_date=common_util.get_default_datetime()))
|
||||
session.execute(update_flag)
|
||||
session.commit()
|
||||
session.close()
|
||||
except Exception as e:
|
||||
logging.error("上传图片出错", exc_info=e)
|
||||
finally:
|
||||
util.delete_temp_file(temp_file.name)
|
||||
|
||||
# 删除多余图片
|
||||
if os.path.exists(processed_img_dir) and os.path.isdir(processed_img_dir):
|
||||
shutil.rmtree(processed_img_dir)
|
||||
|
||||
|
||||
def main():
|
||||
@@ -282,7 +284,7 @@ def main():
|
||||
update_flag = (update(ZxPhhd).where(ZxPhhd.pk_phhd == pk_phhd).values(
|
||||
paint_flag="8",
|
||||
paint_user=HOSTNAME,
|
||||
paint_date=util.get_default_datetime(),
|
||||
paint_date=common_util.get_default_datetime(),
|
||||
fZcfwfy=time.time() - start_time))
|
||||
session.execute(update_flag)
|
||||
session.commit()
|
||||
|
||||
@@ -8,7 +8,7 @@ from db import MysqlSession
|
||||
from db.mysql import ZxIeOcrerror, ZxPhrec
|
||||
from photo_mask.auto_photo_mask import mask_photo
|
||||
from ucloud import ufile
|
||||
from util import image_util, util
|
||||
from util import image_util, common_util
|
||||
|
||||
|
||||
def check_error(error_ocr):
|
||||
@@ -91,7 +91,7 @@ if __name__ == '__main__':
|
||||
|
||||
session = MysqlSession()
|
||||
update_error = (update(ZxIeOcrerror).where(ZxIeOcrerror.pk_phrec == ocr_error.pk_phrec).values(
|
||||
checktime=util.get_default_datetime(), cfjaddress2=error_descript))
|
||||
checktime=common_util.get_default_datetime(), cfjaddress2=error_descript))
|
||||
session.execute(update_error)
|
||||
session.commit()
|
||||
session.close()
|
||||
|
||||
@@ -7,7 +7,7 @@ from sqlalchemy import update, and_
|
||||
from db import MysqlSession
|
||||
from db.mysql import ZxIeOcrerror
|
||||
from photo_mask.photo_mask_error_check import auto_check_error
|
||||
from util import util
|
||||
from util import common_util
|
||||
|
||||
if __name__ == '__main__':
|
||||
today = date.today()
|
||||
@@ -29,7 +29,7 @@ if __name__ == '__main__':
|
||||
if error_descript == "未知错误":
|
||||
check_time = None
|
||||
else:
|
||||
check_time = util.get_default_datetime()
|
||||
check_time = common_util.get_default_datetime()
|
||||
|
||||
session = MysqlSession()
|
||||
update_error = (update(ZxIeOcrerror).where(ZxIeOcrerror.pk_phrec == ocr_error.pk_phrec).values(
|
||||
@@ -41,5 +41,5 @@ if __name__ == '__main__':
|
||||
print(result)
|
||||
with open("photo_mask_error_report.txt", 'w', encoding='utf-8') as file:
|
||||
file.write(json.dumps(result, indent=4, ensure_ascii=False))
|
||||
file.write(util.get_default_datetime())
|
||||
file.write(common_util.get_default_datetime())
|
||||
print("结果已保存。")
|
||||
|
||||
@@ -5,36 +5,36 @@ from time import sleep
|
||||
|
||||
from sqlalchemy import update
|
||||
|
||||
from auto_email.error_email import send_error_email
|
||||
from db import MysqlSession
|
||||
from db.mysql import ZxPhhd
|
||||
from log import LOGGING_CONFIG
|
||||
from my_email.error_email import send_error_email
|
||||
from photo_review import auto_photo_review, SEND_ERROR_EMAIL
|
||||
|
||||
# 项目必须从此处启动,否则代码中的相对路径可能导致错误的发生
|
||||
# 照片审核自动识别脚本入口
|
||||
if __name__ == '__main__':
|
||||
program_name = '照片审核自动识别脚本'
|
||||
logging.config.dictConfig(LOGGING_CONFIG)
|
||||
|
||||
logging.info('等待接口服务启动...')
|
||||
sleep(60)
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--clean", default=False, type=bool, help="是否将识别中的案子改为待识别状态")
|
||||
parser.add_argument('--clean', default=False, type=bool, help='是否将识别中的案子改为待识别状态')
|
||||
args = parser.parse_args()
|
||||
if args.clean:
|
||||
# 主要用于启动时,清除仍在涂抹中的案子
|
||||
# 启动时清除仍在识别中的案子
|
||||
session = MysqlSession()
|
||||
update_flag = (update(ZxPhhd).where(ZxPhhd.exsuccess_flag == "2").values(exsuccess_flag="1"))
|
||||
update_flag = (update(ZxPhhd).where(ZxPhhd.exsuccess_flag == '2').values(exsuccess_flag='1'))
|
||||
session.execute(update_flag)
|
||||
session.commit()
|
||||
session.close()
|
||||
logging.info("已释放残余的识别案子!")
|
||||
else:
|
||||
sleep(5)
|
||||
logging.info('已释放残余的识别案子!')
|
||||
|
||||
try:
|
||||
logging.info(f"【{program_name}】开始运行")
|
||||
logging.info(f'【{program_name}】开始运行')
|
||||
auto_photo_review.main()
|
||||
except Exception as e:
|
||||
error_logger = logging.getLogger('error')
|
||||
error_logger.error(traceback.format_exc())
|
||||
logging.getLogger('error').error(traceback.format_exc())
|
||||
if SEND_ERROR_EMAIL:
|
||||
send_error_email(program_name, repr(e), traceback.format_exc())
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import jieba
|
||||
from paddlenlp import Taskflow
|
||||
from paddleocr import PaddleOCR
|
||||
|
||||
'''
|
||||
项目配置
|
||||
@@ -11,52 +9,8 @@ PHHD_BATCH_SIZE = 10
|
||||
SLEEP_MINUTES = 5
|
||||
# 是否发送报错邮件
|
||||
SEND_ERROR_EMAIL = True
|
||||
# 是否开启布局分析
|
||||
LAYOUT_ANALYSIS = False
|
||||
|
||||
"""
|
||||
信息抽取关键词配置
|
||||
"""
|
||||
# 患者姓名
|
||||
PATIENT_NAME = ['患者姓名']
|
||||
# 入院日期
|
||||
ADMISSION_DATE = ['入院日期']
|
||||
# 出院日期
|
||||
DISCHARGE_DATE = ['出院日期']
|
||||
# 发生医疗费
|
||||
MEDICAL_EXPENSES = ['费用总额']
|
||||
# 个人现金支付
|
||||
PERSONAL_CASH_PAYMENT = ['个人现金支付']
|
||||
# 个人账户支付
|
||||
PERSONAL_ACCOUNT_PAYMENT = ['个人账户支付']
|
||||
# 个人自费金额
|
||||
PERSONAL_FUNDED_AMOUNT = ['自费金额', '个人自费']
|
||||
# 医保类别
|
||||
MEDICAL_INSURANCE_TYPE = ['医保类型']
|
||||
# 就诊医院
|
||||
HOSPITAL = ['医院']
|
||||
# 就诊科室
|
||||
DEPARTMENT = ['科室']
|
||||
# 主治医生
|
||||
DOCTOR = ['主治医生']
|
||||
# 住院号
|
||||
ADMISSION_ID = ['住院号']
|
||||
# 医保结算单号码
|
||||
SETTLEMENT_ID = ['医保结算单号码']
|
||||
# 年龄
|
||||
AGE = ['年龄']
|
||||
# 大写总额
|
||||
UPPERCASE_MEDICAL_EXPENSES = ['大写总额']
|
||||
|
||||
SETTLEMENT_LIST_SCHEMA = \
|
||||
(PATIENT_NAME + ADMISSION_DATE + DISCHARGE_DATE + MEDICAL_EXPENSES + PERSONAL_CASH_PAYMENT
|
||||
+ PERSONAL_ACCOUNT_PAYMENT + PERSONAL_FUNDED_AMOUNT + MEDICAL_INSURANCE_TYPE + ADMISSION_ID + SETTLEMENT_ID
|
||||
+ UPPERCASE_MEDICAL_EXPENSES)
|
||||
|
||||
DISCHARGE_RECORD_SCHEMA = \
|
||||
HOSPITAL + DEPARTMENT + PATIENT_NAME + ADMISSION_DATE + DISCHARGE_DATE + DOCTOR + ADMISSION_ID + AGE
|
||||
|
||||
COST_LIST_SCHEMA = PATIENT_NAME + ADMISSION_DATE + DISCHARGE_DATE + MEDICAL_EXPENSES
|
||||
# 处理批号(这里仅起声明作用)
|
||||
BATCH_ID = ''
|
||||
|
||||
'''
|
||||
别名配置
|
||||
@@ -92,23 +46,32 @@ jieba.suggest_freq(('胆', '道'), True)
|
||||
jieba.suggest_freq(('脾', '胃'), True)
|
||||
|
||||
'''
|
||||
模型配置
|
||||
出院记录缺页判断关键词配置
|
||||
'''
|
||||
SETTLEMENT_IE = Taskflow('information_extraction', schema=SETTLEMENT_LIST_SCHEMA, model='uie-x-base',
|
||||
task_path='model/settlement_list_model', layout_analysis=LAYOUT_ANALYSIS, precision='fp16')
|
||||
DISCHARGE_IE = Taskflow('information_extraction', schema=DISCHARGE_RECORD_SCHEMA, model='uie-x-base',
|
||||
task_path='model/discharge_record_model', layout_analysis=LAYOUT_ANALYSIS, precision='fp16')
|
||||
COST_IE = Taskflow('information_extraction', schema=COST_LIST_SCHEMA, model='uie-x-base', device_id=1,
|
||||
task_path='model/cost_list_model', layout_analysis=LAYOUT_ANALYSIS, precision='fp16')
|
||||
DISCHARGE_KEY = {
|
||||
'入院诊断': ['入院诊断'],
|
||||
'入院情况': ['入院情况', '入院时情况', '入院时主要症状'],
|
||||
'入院日期': ['入院日期', '入院时间'],
|
||||
'诊疗经过': ['诊疗经过', '住院经过', '治疗经过'],
|
||||
'出院诊断': ['出院诊断'],
|
||||
'出院情况': ['出院情况', '出院时情况'],
|
||||
'出院日期': ['出院日期', '出院时间'],
|
||||
'出院医嘱': ['出院医嘱', '出院医瞩']
|
||||
}
|
||||
|
||||
OCR = PaddleOCR(
|
||||
gpu_id=1,
|
||||
use_angle_cls=False,
|
||||
show_log=False,
|
||||
det_db_thresh=0.1,
|
||||
det_db_box_thresh=0.3,
|
||||
det_limit_side_len=1248,
|
||||
drop_score=0.3,
|
||||
rec_model_dir='model/ocr/openatom_rec_repsvtr_ch_infer',
|
||||
rec_algorithm='SVTR_LCNet',
|
||||
)
|
||||
|
||||
def get_batch_id():
|
||||
"""
|
||||
获取处理批号
|
||||
:return: 处理批号
|
||||
"""
|
||||
return BATCH_ID
|
||||
|
||||
|
||||
def set_batch_id(batch_id):
|
||||
"""
|
||||
修改处理批号哦
|
||||
:param batch_id: 新批号
|
||||
"""
|
||||
global BATCH_ID
|
||||
BATCH_ID = batch_id
|
||||
|
||||
@@ -1,119 +1,76 @@
|
||||
import copy
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import re
|
||||
import shutil
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from time import sleep
|
||||
|
||||
import cv2
|
||||
import fitz
|
||||
import jieba
|
||||
import numpy as np
|
||||
import requests
|
||||
import zxingcpp
|
||||
from rapidfuzz import process, fuzz
|
||||
from sqlalchemy import update
|
||||
|
||||
from db import MysqlSession
|
||||
from db.mysql import BdYljg, BdYlks, ZxIeResult, ZxIeCost, ZxIeDischarge, ZxIeSettlement, ZxPhhd, ZxPhrec
|
||||
from db.mysql import BdYljg, BdYlks, ZxIeCost, ZxIeDischarge, ZxIeSettlement, ZxPhhd, ZxPhrec, ZxIeReview, ZxIeResult
|
||||
from log import HOSTNAME
|
||||
from photo_review import PATIENT_NAME, ADMISSION_DATE, DISCHARGE_DATE, MEDICAL_EXPENSES, PERSONAL_CASH_PAYMENT, \
|
||||
PERSONAL_ACCOUNT_PAYMENT, PERSONAL_FUNDED_AMOUNT, MEDICAL_INSURANCE_TYPE, HOSPITAL, DEPARTMENT, DOCTOR, \
|
||||
ADMISSION_ID, SETTLEMENT_ID, AGE, OCR, SETTLEMENT_IE, DISCHARGE_IE, COST_IE, PHHD_BATCH_SIZE, SLEEP_MINUTES, \
|
||||
UPPERCASE_MEDICAL_EXPENSES, HOSPITAL_ALIAS, HOSPITAL_FILTER, DEPARTMENT_ALIAS, DEPARTMENT_FILTER
|
||||
from photo_review import PHHD_BATCH_SIZE, SLEEP_MINUTES, HOSPITAL_ALIAS, HOSPITAL_FILTER, DEPARTMENT_ALIAS, \
|
||||
DEPARTMENT_FILTER, DISCHARGE_KEY, set_batch_id, get_batch_id
|
||||
from services.paddle_services import IE_KEY
|
||||
from ucloud import ufile, BUCKET
|
||||
from util import image_util, util, html_util
|
||||
from util.data_util import handle_date, handle_decimal, parse_department, handle_name, \
|
||||
handle_insurance_type, handle_original_data, handle_hospital, handle_department, handle_age, parse_money, \
|
||||
parse_hospital, handle_doctor, handle_text, handle_admission_id, handle_settlement_id
|
||||
from util import image_util, common_util, html_util, model_util
|
||||
from util.data_util import handle_date, handle_decimal, parse_department, handle_name, handle_insurance_type, \
|
||||
handle_original_data, handle_hospital, handle_department, handle_id, handle_age, parse_money, parse_hospital, \
|
||||
parse_page_num, handle_tiny_int
|
||||
|
||||
|
||||
# 合并信息抽取结果
|
||||
def merge_result(result1, result2):
|
||||
for key in result2:
|
||||
result1[key] = result1.get(key, []) + result2[key]
|
||||
return result1
|
||||
def parse_qrcode(img_path, image_id):
|
||||
"""
|
||||
解析二维码,尝试从中获取高清图片
|
||||
:param img_path: 待解析图片
|
||||
:param image_id: 图片id
|
||||
:return: 解析结果
|
||||
"""
|
||||
|
||||
|
||||
def ie_temp_image(ie, ocr, image):
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as temp_file:
|
||||
cv2.imwrite(temp_file.name, image)
|
||||
|
||||
ie_result = []
|
||||
ocr_pure_text = ''
|
||||
try:
|
||||
layout = util.get_ocr_layout(ocr, temp_file.name)
|
||||
if not layout:
|
||||
# 无识别结果
|
||||
ie_result = []
|
||||
else:
|
||||
ie_result = ie({"doc": temp_file.name, "layout": layout})[0]
|
||||
for lay in layout:
|
||||
ocr_pure_text += lay[1]
|
||||
except MemoryError as e:
|
||||
# 显存不足时应该抛出错误,让程序重启,同时释放显存
|
||||
raise e
|
||||
except Exception as e:
|
||||
logging.error("信息抽取时出错", exc_info=e)
|
||||
finally:
|
||||
try:
|
||||
os.remove(temp_file.name)
|
||||
except Exception as e:
|
||||
logging.info(f"删除临时文件 {temp_file.name} 时出错", exc_info=e)
|
||||
return ie_result, ocr_pure_text
|
||||
|
||||
|
||||
# 关键信息提取
|
||||
def request_ie_result(task_enum, phrecs):
|
||||
url = task_enum.request_url()
|
||||
identity = int(time.time())
|
||||
images = []
|
||||
for phrec in phrecs:
|
||||
images.append({"name": phrec.cfjaddress, "pk": phrec.pk_phrec})
|
||||
payload = {"images": images, "schema": task_enum.schema(), "pk_phhd": phrecs[0].pk_phhd, "identity": identity}
|
||||
response = requests.post(url, json=payload)
|
||||
|
||||
if response.status_code == 200:
|
||||
return response.json()["data"]
|
||||
else:
|
||||
raise Exception(f"请求信息抽取结果失败,状态码:{response.status_code}")
|
||||
|
||||
|
||||
# 尝试从二维码中获取高清图片
|
||||
def get_better_image_from_qrcode(image, image_id, dpi=150):
|
||||
def _parse_pdf_url(pdf_url_to_parse):
|
||||
pdf_file = None
|
||||
local_pdf_path = None
|
||||
img_name, img_ext = common_util.parse_save_path(img_path)
|
||||
try:
|
||||
local_pdf_path = html_util.download_pdf(pdf_url_to_parse)
|
||||
# 打开PDF文件
|
||||
pdf_file = fitz.open(local_pdf_path)
|
||||
# 选择第一页
|
||||
page = pdf_file[0]
|
||||
# 定义缩放系数(DPI)
|
||||
default_dpi = 72
|
||||
zoom = dpi / default_dpi
|
||||
# 设置矩阵变换参数
|
||||
mat = fitz.Matrix(zoom, zoom)
|
||||
# 渲染页面
|
||||
pix = page.get_pixmap(matrix=mat)
|
||||
# 将渲染结果转换为OpenCV兼容的格式
|
||||
img = np.frombuffer(pix.samples, dtype=np.uint8).reshape((pix.height, pix.width, -1))
|
||||
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
||||
return img, page.get_text()
|
||||
pdf_imgs = image_util.pdf_to_imgs(local_pdf_path)
|
||||
# 结算单部分
|
||||
better_settlement_path = common_util.get_processed_img_path(f'{img_name}.better_settlement.jpg')
|
||||
cv2.imwrite(better_settlement_path, pdf_imgs[0][0])
|
||||
# 费用清单部分
|
||||
better_cost_path = common_util.get_processed_img_path(f'{img_name}.better_cost.jpg')
|
||||
total_height = sum([p[0].shape[0] for p in pdf_imgs[1:]])
|
||||
common_width = pdf_imgs[1][0].shape[1]
|
||||
better_cost_img = np.zeros((total_height, common_width, 3), dtype=np.uint8)
|
||||
current_y = 0
|
||||
for pdf in pdf_imgs[1:]:
|
||||
height = pdf[0].shape[0]
|
||||
better_cost_img[current_y:current_y + height, :, :] = pdf[0]
|
||||
current_y += height
|
||||
# cost_text += pdf[1] # 费用清单文本暂时没用到
|
||||
cv2.imwrite(better_cost_path, better_cost_img)
|
||||
|
||||
return better_settlement_path, pdf_imgs[0][1], better_cost_path
|
||||
except Exception as ex:
|
||||
logging.getLogger('error').error('解析pdf失败!', exc_info=ex)
|
||||
return None, None
|
||||
return None, None, None
|
||||
finally:
|
||||
if pdf_file:
|
||||
pdf_file.close()
|
||||
if local_pdf_path:
|
||||
util.delete_temp_file(local_pdf_path)
|
||||
common_util.delete_temp_file(local_pdf_path)
|
||||
|
||||
jsczt_base_url = 'http://einvoice.jsczt.cn'
|
||||
try:
|
||||
results = zxingcpp.read_barcodes(image)
|
||||
img = cv2.imread(img_path)
|
||||
results = zxingcpp.read_barcodes(img, text_mode=zxingcpp.TextMode.HRI)
|
||||
except Exception as e:
|
||||
logging.getLogger('error').info('二维码识别失败', exc_info=e)
|
||||
results = []
|
||||
@@ -138,159 +95,122 @@ def get_better_image_from_qrcode(image, image_id, dpi=150):
|
||||
if not pdf_url:
|
||||
continue
|
||||
return _parse_pdf_url(pdf_url)
|
||||
elif url.startswith('http://weixin.qq.com'):
|
||||
elif (url.startswith('http://weixin.qq.com')
|
||||
or url == 'https://ybj.jszwfw.gov.cn/hsa-app-panel/index.html'):
|
||||
# 无效地址
|
||||
continue
|
||||
elif url.startswith('http://dzpj.ntzyy.com'):
|
||||
# 南通市中医院
|
||||
return _parse_pdf_url(url)
|
||||
# elif url.startswith('https://apph5.ztejsapp.cn/nj/view/elecInvoiceForOther/QRCode2Invoice'):
|
||||
# pdf_url = html_util.get_dtsrmyy_pdf_url(url)
|
||||
# if not pdf_url:
|
||||
# continue
|
||||
# return _parse_pdf_url(pdf_url)
|
||||
else:
|
||||
logging.getLogger('qr').info(f'[{image_id}]中有未知二维码内容:{url}')
|
||||
except Exception as e:
|
||||
logging.getLogger('error').error('从二维码中获取高清图片时出错', exc_info=e)
|
||||
continue
|
||||
|
||||
return None, None
|
||||
return None, None, None
|
||||
|
||||
|
||||
# 关键信息提取
|
||||
def information_extraction(ie, phrecs, identity):
|
||||
result = {}
|
||||
def information_extraction(phrec, pk_phhd):
|
||||
"""
|
||||
处理单张图片
|
||||
:param phrec:图片信息
|
||||
:param pk_phhd:案子主键
|
||||
:return:记录类型,信息抽取结果
|
||||
"""
|
||||
img_path = common_util.get_processed_img_path(phrec.cfjaddress)
|
||||
if not os.path.exists(img_path):
|
||||
original_img_path = common_util.get_img_path(phrec.cfjaddress)
|
||||
if not original_img_path:
|
||||
img_url = ufile.get_private_url(phrec.cfjaddress)
|
||||
if not img_url:
|
||||
return None, None, None
|
||||
original_img_path = common_util.save_to_local(img_url)
|
||||
shutil.copy2(original_img_path, img_path)
|
||||
if image_util.is_photo(img_path):
|
||||
book_img_path = model_util.det_book(img_path) # 识别文档区域并裁剪
|
||||
dewarped_img_path = model_util.dewarp(book_img_path) # 去扭曲
|
||||
else: # todo:也可能是图片,后续添加细分逻辑
|
||||
dewarped_img_path = img_path
|
||||
angles = model_util.clas_orientation(dewarped_img_path)
|
||||
ocr_text = ''
|
||||
for phrec in phrecs:
|
||||
img_path = ufile.get_private_url(phrec.cfjaddress)
|
||||
if not img_path:
|
||||
continue
|
||||
|
||||
image = image_util.read(img_path)
|
||||
if image is None:
|
||||
# 图片可能因为某些原因获取不到
|
||||
continue
|
||||
|
||||
# 尝试从二维码中获取高清图片
|
||||
better_image, text = get_better_image_from_qrcode(image, phrec.cfjaddress)
|
||||
if phrec.cRectype != '1':
|
||||
better_image = None # 非结算单暂时不进行替换
|
||||
zx_ie_results = []
|
||||
if better_image is not None:
|
||||
img_angle = '0'
|
||||
image = better_image
|
||||
if text:
|
||||
info_extract = ie(text)[0]
|
||||
else:
|
||||
info_extract = ie_temp_image(ie, OCR, image)[0]
|
||||
ie_result = {'result': info_extract, 'angle': '0'}
|
||||
|
||||
now = util.get_default_datetime()
|
||||
if not ie_result['result']:
|
||||
info_extract = []
|
||||
rec_type = None
|
||||
for angle in angles:
|
||||
ocr_result = []
|
||||
rotated_img = image_util.rotate(dewarped_img_path, int(angle))
|
||||
split_results = image_util.split(rotated_img)
|
||||
for split_result in split_results:
|
||||
if split_result['img'] is None:
|
||||
continue
|
||||
a4_img = image_util.expand_to_a4_size(split_result['img'])
|
||||
tmp_ocr_result = model_util.ocr(a4_img)
|
||||
if tmp_ocr_result:
|
||||
ocr_result += tmp_ocr_result
|
||||
tmp_ocr_text = common_util.ocr_result_to_text(ocr_result)
|
||||
|
||||
result_json = json.dumps(ie_result['result'], ensure_ascii=False)
|
||||
if len(result_json) > 5000:
|
||||
result_json = result_json[:5000]
|
||||
zx_ie_results.append(ZxIeResult(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, id=identity,
|
||||
cfjaddress=phrec.cfjaddress, content=result_json,
|
||||
rotation_angle=int(ie_result['angle']),
|
||||
x_offset=0, y_offset=0, create_time=now,
|
||||
creator=HOSTNAME, update_time=now, updater=HOSTNAME))
|
||||
|
||||
result = merge_result(result, ie_result['result'])
|
||||
# if any(key in tmp_ocr_text for key in ['出院记录', '出院小结', '死亡记录']):
|
||||
# tmp_rec_type = '出院记录'
|
||||
# elif any(key in tmp_ocr_text for key in ['费用汇总清单', '费用清单', '费用明细', '结账清单', '费用小项统计']):
|
||||
# tmp_rec_type = '费用清单'
|
||||
# elif any(key in tmp_ocr_text for key in ['住院收费票据', '结算单', '财政部监制', '结算凭证']):
|
||||
# tmp_rec_type = '基本医保结算单'
|
||||
# else:
|
||||
# tmp_rec_type = model_util.clas_text(tmp_ocr_text) if tmp_ocr_text else None
|
||||
# if not tmp_rec_type:
|
||||
rec_dict = {
|
||||
'1': '基本医保结算单',
|
||||
'3': '出院记录',
|
||||
'4': '费用清单',
|
||||
}
|
||||
tmp_rec_type = rec_dict.get(phrec.cRectype)
|
||||
if tmp_rec_type == '基本医保结算单':
|
||||
tmp_info_extract = model_util.ie_settlement(rotated_img, common_util.ocr_result_to_layout(ocr_result))
|
||||
elif tmp_rec_type == '出院记录':
|
||||
tmp_info_extract = model_util.ie_discharge(rotated_img, common_util.ocr_result_to_layout(ocr_result))
|
||||
elif tmp_rec_type == '费用清单':
|
||||
tmp_info_extract = model_util.ie_cost(rotated_img, common_util.ocr_result_to_layout(ocr_result))
|
||||
else:
|
||||
target_images = []
|
||||
# target_images += detector.request_book_areas(image) # 识别文档区域并裁剪
|
||||
if not target_images:
|
||||
target_images.append(image) # 识别失败
|
||||
angle_count = defaultdict(int, {'0': 0}) # 分割后图片的最优角度统计
|
||||
for target_image in target_images:
|
||||
# dewarped_image = dewarp.dewarp_image(target_image) # 去扭曲
|
||||
dewarped_image = target_image
|
||||
angles = image_util.parse_rotation_angles(dewarped_image)
|
||||
tmp_info_extract = []
|
||||
|
||||
split_results = image_util.split(dewarped_image)
|
||||
for split_result in split_results:
|
||||
if split_result['img'] is None or split_result['img'].size == 0:
|
||||
continue
|
||||
rotated_img = image_util.rotate(split_result['img'], int(angles[0]))
|
||||
ie_temp_result = ie_temp_image(ie, OCR, rotated_img)
|
||||
ocr_text += ie_temp_result[1]
|
||||
ie_results = [{'result': ie_temp_result[0], 'angle': angles[0]}]
|
||||
if not ie_results[0]['result'] or len(ie_results[0]['result']) < len(ie.kwargs.get('schema')):
|
||||
rotated_img = image_util.rotate(split_result['img'], int(angles[1]))
|
||||
ie_results.append({'result': ie_temp_image(ie, OCR, rotated_img)[0], 'angle': angles[1]})
|
||||
now = util.get_default_datetime()
|
||||
best_angle = ['0', 0]
|
||||
for ie_result in ie_results:
|
||||
if not ie_result['result']:
|
||||
continue
|
||||
if tmp_info_extract and len(tmp_info_extract) > len(info_extract):
|
||||
info_extract = tmp_info_extract
|
||||
ocr_text = tmp_ocr_text
|
||||
rec_type = tmp_rec_type
|
||||
|
||||
result_json = json.dumps(ie_result['result'], ensure_ascii=False)
|
||||
if len(result_json) > 5000:
|
||||
result_json = result_json[:5000]
|
||||
zx_ie_results.append(ZxIeResult(pk_phhd=phrec.pk_phhd, pk_phrec=phrec.pk_phrec, id=identity,
|
||||
cfjaddress=phrec.cfjaddress, content=result_json,
|
||||
rotation_angle=int(ie_result['angle']),
|
||||
x_offset=split_result['x_offset'],
|
||||
y_offset=split_result['y_offset'], create_time=now,
|
||||
creator=HOSTNAME, update_time=now, updater=HOSTNAME))
|
||||
|
||||
result = merge_result(result, ie_result['result'])
|
||||
|
||||
if len(ie_result['result']) > best_angle[1]:
|
||||
best_angle = [ie_result['angle'], len(ie_result['result'])]
|
||||
|
||||
angle_count[best_angle[0]] += 1
|
||||
img_angle = max(angle_count, key=angle_count.get)
|
||||
|
||||
if img_angle != '0' or better_image is not None:
|
||||
image = image_util.rotate(image, int(img_angle))
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix='.jpg') as temp_file:
|
||||
cv2.imwrite(temp_file.name, image)
|
||||
try:
|
||||
if img_angle != '0':
|
||||
ufile.upload_file(phrec.cfjaddress, temp_file.name)
|
||||
logging.info(f'旋转图片[{phrec.cfjaddress}]替换成功,已旋转{img_angle}度。')
|
||||
# 修正旋转角度
|
||||
for zx_ie_result in zx_ie_results:
|
||||
zx_ie_result.rotation_angle -= int(img_angle)
|
||||
else:
|
||||
ufile.copy_file(BUCKET, phrec.cfjaddress, "drg2015", phrec.cfjaddress)
|
||||
ufile.upload_file(phrec.cfjaddress, temp_file.name)
|
||||
logging.info(f'高清图片[{phrec.cfjaddress}]替换成功!')
|
||||
except Exception as e:
|
||||
logging.error(f'上传图片({phrec.cfjaddress})失败', exc_info=e)
|
||||
finally:
|
||||
util.delete_temp_file(temp_file.name)
|
||||
if info_extract:
|
||||
result_json = json.dumps(info_extract, ensure_ascii=False)
|
||||
if len(result_json) > 5000:
|
||||
result_json = result_json[:5000]
|
||||
|
||||
now = common_util.get_default_datetime()
|
||||
session = MysqlSession()
|
||||
session.add_all(zx_ie_results)
|
||||
session.add(ZxIeResult(pk_phhd=pk_phhd, pk_phrec=phrec.pk_phrec, id=get_batch_id(),
|
||||
cfjaddress=phrec.cfjaddress, content=result_json, create_time=now,
|
||||
creator=HOSTNAME, update_time=now, updater=HOSTNAME))
|
||||
session.commit()
|
||||
|
||||
# # 添加清晰度测试
|
||||
# if better_image is None:
|
||||
# # 替换后图片默认清晰
|
||||
# clarity_result = image_util.parse_clarity(image)
|
||||
# unsharp_flag = 0 if (clarity_result[0] == 0 and clarity_result[1] >= 0.8) else 1
|
||||
# update_clarity = (update(ZxPhrec).where(ZxPhrec.pk_phrec == phrec.pk_phrec).values(
|
||||
# cfjaddress2=json.dumps(clarity_result),
|
||||
# unsharp_flag=unsharp_flag,
|
||||
# ))
|
||||
# session.execute(update_clarity)
|
||||
# session.commit()
|
||||
# session.close()
|
||||
|
||||
result['ocr_text'] = ocr_text
|
||||
return result
|
||||
session.close()
|
||||
return rec_type, info_extract, ocr_text
|
||||
|
||||
|
||||
# 从keys中获取准确率最高的value
|
||||
def get_best_value_in_keys(source, keys):
|
||||
def get_best_value_of_key(source, key):
|
||||
# 最终结果
|
||||
result = None
|
||||
# 最大可能性
|
||||
best_probability = 0
|
||||
for key in keys:
|
||||
values = source.get(key)
|
||||
if values:
|
||||
for value in values:
|
||||
text = value.get("text")
|
||||
probability = value.get("probability")
|
||||
values = source.get(key)
|
||||
if values:
|
||||
for value in values:
|
||||
for v in value:
|
||||
text = v.get("text")
|
||||
probability = v.get("probability")
|
||||
if text and probability > best_probability:
|
||||
result = text
|
||||
best_probability = probability
|
||||
@@ -298,11 +218,11 @@ def get_best_value_in_keys(source, keys):
|
||||
|
||||
|
||||
# 从keys中获取所有value组成list
|
||||
def get_values_of_keys(source, keys):
|
||||
def get_values_of_key(source, key):
|
||||
result = []
|
||||
for key in keys:
|
||||
value = source.get(key)
|
||||
if value:
|
||||
values = source.get(key)
|
||||
if values:
|
||||
for value in values:
|
||||
for v in value:
|
||||
v = v.get("text")
|
||||
if v:
|
||||
@@ -316,11 +236,11 @@ def save_or_update_ie(table, pk_phhd, data):
|
||||
obj = table(**data)
|
||||
session = MysqlSession()
|
||||
db_data = session.query(table).filter_by(pk_phhd=pk_phhd).one_or_none()
|
||||
now = util.get_default_datetime()
|
||||
now = common_util.get_default_datetime()
|
||||
if db_data:
|
||||
# 更新
|
||||
db_data.update_time = now
|
||||
db_data.updater = HOSTNAME
|
||||
db_data.creator = HOSTNAME
|
||||
for k, v in data.items():
|
||||
setattr(db_data, k, v)
|
||||
else:
|
||||
@@ -391,23 +311,24 @@ def search_department(department):
|
||||
return best_match
|
||||
|
||||
|
||||
def settlement_task(pk_phhd, settlement_list, identity):
|
||||
settlement_list_ie_result = information_extraction(SETTLEMENT_IE, settlement_list, identity)
|
||||
def settlement_task(pk_phhd, settlement_list_ie_result):
|
||||
settlement_data = {
|
||||
"pk_phhd": pk_phhd,
|
||||
"name": handle_name(get_best_value_in_keys(settlement_list_ie_result, PATIENT_NAME)),
|
||||
"admission_date_str": handle_original_data(get_best_value_in_keys(settlement_list_ie_result, ADMISSION_DATE)),
|
||||
"discharge_date_str": handle_original_data(get_best_value_in_keys(settlement_list_ie_result, DISCHARGE_DATE)),
|
||||
"name": handle_name(get_best_value_of_key(settlement_list_ie_result, IE_KEY['name'])),
|
||||
"admission_date_str": handle_original_data(
|
||||
get_best_value_of_key(settlement_list_ie_result, IE_KEY['admission_date'])),
|
||||
"discharge_date_str": handle_original_data(
|
||||
get_best_value_of_key(settlement_list_ie_result, IE_KEY['discharge_date'])),
|
||||
"personal_cash_payment_str": handle_original_data(
|
||||
get_best_value_in_keys(settlement_list_ie_result, PERSONAL_CASH_PAYMENT)),
|
||||
get_best_value_of_key(settlement_list_ie_result, IE_KEY['personal_cash_payment'])),
|
||||
"personal_account_payment_str": handle_original_data(
|
||||
get_best_value_in_keys(settlement_list_ie_result, PERSONAL_ACCOUNT_PAYMENT)),
|
||||
get_best_value_of_key(settlement_list_ie_result, IE_KEY['personal_account_payment'])),
|
||||
"personal_funded_amount_str": handle_original_data(
|
||||
get_best_value_in_keys(settlement_list_ie_result, PERSONAL_FUNDED_AMOUNT)),
|
||||
get_best_value_of_key(settlement_list_ie_result, IE_KEY['personal_funded_amount'])),
|
||||
"medical_insurance_type_str": handle_original_data(
|
||||
get_best_value_in_keys(settlement_list_ie_result, MEDICAL_INSURANCE_TYPE)),
|
||||
"admission_id": handle_admission_id(get_best_value_in_keys(settlement_list_ie_result, ADMISSION_ID)),
|
||||
"settlement_id": handle_settlement_id(get_best_value_in_keys(settlement_list_ie_result, SETTLEMENT_ID)),
|
||||
get_best_value_of_key(settlement_list_ie_result, IE_KEY['medical_insurance_type'])),
|
||||
"admission_id": handle_id(get_best_value_of_key(settlement_list_ie_result, IE_KEY['admission_id'])),
|
||||
"settlement_id": handle_id(get_best_value_of_key(settlement_list_ie_result, IE_KEY['settlement_id'])),
|
||||
}
|
||||
settlement_data["admission_date"] = handle_date(settlement_data["admission_date_str"])
@@ -417,32 +338,30 @@ def settlement_task(pk_phhd, settlement_list, identity):
|
||||
settlement_data["personal_funded_amount"] = handle_decimal(settlement_data["personal_funded_amount_str"])
|
||||
settlement_data["medical_insurance_type"] = handle_insurance_type(settlement_data["medical_insurance_type_str"])
|
||||
|
||||
parse_money_result = parse_money(get_best_value_in_keys(settlement_list_ie_result, UPPERCASE_MEDICAL_EXPENSES),
|
||||
get_best_value_in_keys(settlement_list_ie_result, MEDICAL_EXPENSES))
|
||||
parse_money_result = parse_money(
|
||||
get_best_value_of_key(settlement_list_ie_result, IE_KEY['uppercase_medical_expenses']),
|
||||
get_best_value_of_key(settlement_list_ie_result, IE_KEY['medical_expenses']))
|
||||
settlement_data["medical_expenses_str"] = handle_original_data(parse_money_result[0])
|
||||
settlement_data["medical_expenses"] = parse_money_result[1]
|
||||
|
||||
if not settlement_data["settlement_id"]:
|
||||
# 如果没有结算单号就填住院号
|
||||
settlement_data["settlement_id"] = settlement_data["admission_id"]
|
||||
save_or_update_ie(ZxIeSettlement, pk_phhd, settlement_data)
|
||||
return settlement_data
|
||||
|
||||
|
||||
def discharge_task(pk_phhd, discharge_record, identity):
|
||||
discharge_record_ie_result = information_extraction(DISCHARGE_IE, discharge_record, identity)
|
||||
hospitals = get_values_of_keys(discharge_record_ie_result, HOSPITAL)
|
||||
departments = get_values_of_keys(discharge_record_ie_result, DEPARTMENT)
|
||||
def discharge_task(pk_phhd, discharge_record_ie_result):
|
||||
hospitals = get_values_of_key(discharge_record_ie_result, IE_KEY['hospital'])
|
||||
departments = get_values_of_key(discharge_record_ie_result, IE_KEY['department'])
|
||||
discharge_data = {
|
||||
"pk_phhd": pk_phhd,
|
||||
"hospital": handle_hospital(",".join(hospitals)),
|
||||
"department": handle_department(",".join(departments)),
|
||||
"name": handle_name(get_best_value_in_keys(discharge_record_ie_result, PATIENT_NAME)),
|
||||
"admission_date_str": handle_original_data(get_best_value_in_keys(discharge_record_ie_result, ADMISSION_DATE)),
|
||||
"discharge_date_str": handle_original_data(get_best_value_in_keys(discharge_record_ie_result, DISCHARGE_DATE)),
|
||||
"doctor": handle_doctor(get_best_value_in_keys(discharge_record_ie_result, DOCTOR)),
|
||||
"admission_id": handle_admission_id(get_best_value_in_keys(discharge_record_ie_result, ADMISSION_ID)),
|
||||
"age": handle_age(get_best_value_in_keys(discharge_record_ie_result, AGE)),
|
||||
"content": handle_text(discharge_record_ie_result['ocr_text']),
|
||||
"name": handle_name(get_best_value_of_key(discharge_record_ie_result, IE_KEY['name'])),
|
||||
"admission_date_str": handle_original_data(
|
||||
get_best_value_of_key(discharge_record_ie_result, IE_KEY['admission_date'])),
|
||||
"discharge_date_str": handle_original_data(
|
||||
get_best_value_of_key(discharge_record_ie_result, IE_KEY['discharge_date'])),
|
||||
"doctor": handle_name(get_best_value_of_key(discharge_record_ie_result, IE_KEY['doctor'])),
|
||||
"admission_id": handle_id(get_best_value_of_key(discharge_record_ie_result, IE_KEY['admission_id'])),
|
||||
"age": handle_age(get_best_value_of_key(discharge_record_ie_result, IE_KEY['age'])),
|
||||
}
|
||||
discharge_data["admission_date"] = handle_date(discharge_data["admission_date_str"])
|
||||
discharge_data["discharge_date"] = handle_date(discharge_data["discharge_date_str"])
|
||||
@@ -499,53 +418,270 @@ def discharge_task(pk_phhd, discharge_record, identity):
|
||||
if best_match:
|
||||
discharge_data["pk_ylks"] = best_match[2]
|
||||
save_or_update_ie(ZxIeDischarge, pk_phhd, discharge_data)
|
||||
return discharge_data
|
||||
|
||||
|
||||
def cost_task(pk_phhd, cost_list, identity):
|
||||
cost_list_ie_result = information_extraction(COST_IE, cost_list, identity)
|
||||
def cost_task(pk_phhd, cost_list_ie_result):
|
||||
cost_data = {
|
||||
"pk_phhd": pk_phhd,
|
||||
"name": handle_name(get_best_value_in_keys(cost_list_ie_result, PATIENT_NAME)),
|
||||
"admission_date_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, ADMISSION_DATE)),
|
||||
"discharge_date_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, DISCHARGE_DATE)),
|
||||
"medical_expenses_str": handle_original_data(get_best_value_in_keys(cost_list_ie_result, MEDICAL_EXPENSES)),
|
||||
"content": handle_text(cost_list_ie_result['ocr_text']),
|
||||
"name": handle_name(get_best_value_of_key(cost_list_ie_result, IE_KEY['name'])),
|
||||
"admission_date_str": handle_original_data(
|
||||
get_best_value_of_key(cost_list_ie_result, IE_KEY['admission_date'])),
|
||||
"discharge_date_str": handle_original_data(
|
||||
get_best_value_of_key(cost_list_ie_result, IE_KEY['discharge_date'])),
|
||||
"medical_expenses_str": handle_original_data(
|
||||
get_best_value_of_key(cost_list_ie_result, IE_KEY['medical_expenses']))
|
||||
}
|
||||
cost_data["admission_date"] = handle_date(cost_data["admission_date_str"])
|
||||
cost_data["discharge_date"] = handle_date(cost_data["discharge_date_str"])
|
||||
cost_data["medical_expenses"] = handle_decimal(cost_data["medical_expenses_str"])
|
||||
if cost_list_ie_result.get(IE_KEY['page']):
|
||||
page_nums, page_count = parse_page_num(cost_list_ie_result[IE_KEY['page']])
|
||||
if page_nums:
|
||||
page_nums_str = [str(num) for num in page_nums]
|
||||
cost_data['page_nums'] = handle_original_data(','.join(page_nums_str))
|
||||
cost_data['page_count'] = handle_tiny_int(page_count)
|
||||
save_or_update_ie(ZxIeCost, pk_phhd, cost_data)
|
||||
return cost_data
|
||||
|
||||
|
||||
def photo_review(pk_phhd):
|
||||
settlement_list = []
|
||||
discharge_record = []
|
||||
cost_list = []
|
||||
def parse_pdf_text(settlement_text):
|
||||
pattern = (r'(?:交款人:(.*?)\n|住院时间:(.*?)至(.*?)\n|\(小写\)(.*?)\n|个人现金支付:(.*?)\n|个人账户支付:(.*?)\n'
|
||||
r'|个人自费:(.*?)\n|医保类型:(.*?)\n|住院科别:(.*?)\n|住院号:(.*?)\n|票据号码:(.*?)\n|)')
|
||||
# 查找所有匹配项
|
||||
matches = re.findall(pattern, settlement_text)
|
||||
results = {}
|
||||
keys = ['患者姓名', '入院日期', '出院日期', '费用总额', '个人现金支付', '个人账户支付', '个人自费', '医保类型',
|
||||
'科室', '住院号', '医保结算单号码']
|
||||
|
||||
for match in matches:
|
||||
for key, value in zip(keys, match):
|
||||
if value:
|
||||
results[key] = [[{'text': value, 'probability': 1}]]
|
||||
settlement_key = ['患者姓名', '入院日期', '出院日期', '费用总额', '个人现金支付', '个人账户支付', '个人自费',
|
||||
'医保类型', '住院号', '医保结算单号码']
|
||||
discharge_key = ['科室', '患者姓名', '入院日期', '出院日期', '住院号']
|
||||
cost_key = ['患者姓名', '入院日期', '出院日期', '费用总额']
|
||||
settlement_result = {key: copy.copy(results[key]) for key in settlement_key if key in results}
|
||||
discharge_result = {key: copy.copy(results[key]) for key in discharge_key if key in results}
|
||||
cost_result = {key: copy.copy(results[key]) for key in cost_key if key in results}
|
||||
return settlement_result, discharge_result, cost_result
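# Illustrative sketch (not part of the change set): the regex above matches one
# labelled field per line of the settlement PDF text, so a fragment such as
#
#   交款人:张三
#   住院时间:2024-01-01至2024-01-10
#
# yields settlement_result['患者姓名'] == [[{'text': '张三', 'probability': 1}]] and
# fills 入院日期/出院日期 the same way; values taken from the PDF text are stored
# with probability 1.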
|
||||
|
||||
|
||||
def photo_review(pk_phhd, name):
|
||||
"""
|
||||
处理单个报销案子
|
||||
:param pk_phhd: 报销单主键
|
||||
:param name: 报销人姓名
|
||||
"""
|
||||
settlement_result = defaultdict(list)
|
||||
discharge_result = defaultdict(list)
|
||||
cost_result = defaultdict(list)
|
||||
|
||||
session = MysqlSession()
|
||||
phrecs = session.query(ZxPhrec.pk_phrec, ZxPhrec.pk_phhd, ZxPhrec.cRectype, ZxPhrec.cfjaddress).filter(
|
||||
phrecs = session.query(ZxPhrec.pk_phrec, ZxPhrec.cRectype, ZxPhrec.cfjaddress).filter(
|
||||
ZxPhrec.pk_phhd == pk_phhd
|
||||
).all()
|
||||
).order_by(ZxPhrec.cRectype, ZxPhrec.rowno).all()
|
||||
session.close()
|
||||
for phrec in phrecs:
|
||||
if phrec.cRectype == "1":
|
||||
settlement_list.append(phrec)
|
||||
elif phrec.cRectype == "3":
|
||||
discharge_record.append(phrec)
|
||||
elif phrec.cRectype == "4":
|
||||
cost_list.append(phrec)
|
||||
|
||||
# 同一批图的标识
|
||||
identity = int(time.time())
|
||||
settlement_task(pk_phhd, settlement_list, identity)
|
||||
discharge_task(pk_phhd, discharge_record, identity)
|
||||
cost_task(pk_phhd, cost_list, identity)
|
||||
set_batch_id(uuid.uuid4().hex)
|
||||
processed_img_dir = common_util.get_processed_img_path('')
|
||||
os.makedirs(processed_img_dir, exist_ok=True)
|
||||
|
||||
has_pdf = False # 是否获取到了pdf,获取到可以直接利用pdf更快的获取信息
|
||||
better_settlement_path = None
|
||||
better_cost_path = None
|
||||
settlement_text = ''
|
||||
qrcode_img_id = None
|
||||
for phrec in phrecs:
|
||||
original_img_path = common_util.get_img_path(phrec.cfjaddress)
|
||||
if not original_img_path:
|
||||
img_url = ufile.get_private_url(phrec.cfjaddress)
|
||||
if not img_url:
|
||||
continue
|
||||
original_img_path = common_util.save_to_local(img_url)
|
||||
img_path = common_util.get_processed_img_path(phrec.cfjaddress)
|
||||
shutil.copy2(original_img_path, img_path)
|
||||
# 尝试从二维码中获取高清图片
|
||||
better_settlement_path, settlement_text, better_cost_path = parse_qrcode(img_path, phrec.cfjaddress)
|
||||
if better_settlement_path:
|
||||
has_pdf = True
|
||||
qrcode_img_id = phrec.cfjaddress
|
||||
break
|
||||
|
||||
discharge_text = ''
|
||||
if has_pdf:
|
||||
settlement_result, discharge_result, cost_result = parse_pdf_text(settlement_text)
|
||||
discharge_ie_result = defaultdict(list)
|
||||
|
||||
is_cost_updated = False
|
||||
for phrec in phrecs:
|
||||
if phrec.cRectype == '1':
|
||||
if phrec.cfjaddress == qrcode_img_id:
|
||||
try:
|
||||
ufile.copy_file(BUCKET, phrec.cfjaddress, "drg2015", phrec.cfjaddress)
|
||||
ufile.upload_file(phrec.cfjaddress, better_settlement_path)
|
||||
except Exception as e:
|
||||
logging.error("更新结算单pdf图片出错", exc_info=e)
|
||||
elif phrec.cRectype == '3':
|
||||
rec_type, ie_result, ocr_text = information_extraction(phrec, pk_phhd)
|
||||
if rec_type == '出院记录':
|
||||
discharge_text += ocr_text
|
||||
for key, value in ie_result.items():
|
||||
discharge_ie_result[key].append(value)
|
||||
# 暂不替换费用清单
|
||||
# elif phrec.cRectype == '4':
|
||||
# if not is_cost_updated:
|
||||
# try:
|
||||
# ufile.copy_file(BUCKET, phrec.cfjaddress, "drg2015", phrec.cfjaddress)
|
||||
# ufile.upload_file(phrec.cfjaddress, better_cost_path)
|
||||
# except Exception as e:
|
||||
# logging.error("更新费用清单pdf图片出错", exc_info=e)
|
||||
# finally:
|
||||
# is_cost_updated = True
|
||||
|
||||
# 合并出院记录
|
||||
for key, value in discharge_ie_result.items():
|
||||
ie_value = get_best_value_of_key(discharge_ie_result, key)
|
||||
pdf_value = discharge_result.get(key)[0][0]['text'] if discharge_result.get(key) else ''
|
||||
similarity_ratio = fuzz.ratio(ie_value, pdf_value)
|
||||
if similarity_ratio < 60:
|
||||
discharge_result[key] = [[{'text': ie_value, 'probability': 1}]]
|
||||
else:
|
||||
for phrec in phrecs:
|
||||
rec_type, ie_result, ocr_text = information_extraction(phrec, pk_phhd)
|
||||
if rec_type == '基本医保结算单':
|
||||
rec_result = settlement_result
|
||||
elif rec_type == '出院记录':
|
||||
rec_result = discharge_result
|
||||
discharge_text += ocr_text
|
||||
elif rec_type == '费用清单':
|
||||
rec_result = cost_result
|
||||
else:
|
||||
rec_result = None
|
||||
if rec_result is not None:
|
||||
for key, value in ie_result.items():
|
||||
rec_result[key].append(value)
|
||||
|
||||
# 删除多余图片
|
||||
if os.path.exists(processed_img_dir) and os.path.isdir(processed_img_dir):
|
||||
shutil.rmtree(processed_img_dir)
|
||||
|
||||
settlement_data = settlement_task(pk_phhd, settlement_result)
|
||||
discharge_data = discharge_task(pk_phhd, discharge_result)
|
||||
cost_data = cost_task(pk_phhd, cost_result)
|
||||
|
||||
# 三项资料完整性判断
|
||||
# 三项资料缺项判断
|
||||
review_result = {
|
||||
'pk_phhd': pk_phhd,
|
||||
'has_settlement': bool(settlement_result),
|
||||
'has_discharge': bool(discharge_result),
|
||||
'has_cost': bool(cost_result),
|
||||
}
|
||||
if (review_result['has_settlement'] and settlement_data.get('personal_account_payment')
|
||||
and settlement_data.get('personal_cash_payment') and settlement_data.get('medical_expenses')):
|
||||
review_result['has_settlement'] &= (
|
||||
float(settlement_data['personal_account_payment']) + float(settlement_data['personal_cash_payment'])
|
||||
< float(settlement_data['medical_expenses'])
|
||||
)
|
||||
if has_pdf:
|
||||
review_result['has_discharge'] &= bool(discharge_text)
|
||||
|
||||
# 三项资料缺页判断
|
||||
page_description = []
|
||||
if review_result['has_discharge']:
|
||||
for discharge_item in DISCHARGE_KEY:
|
||||
if not any(key in discharge_text for key in DISCHARGE_KEY[discharge_item]):
|
||||
page_description.append(f"《出院记录》缺页")
|
||||
break
|
||||
|
||||
if review_result['has_cost']:
|
||||
cost_missing_page = {}
|
||||
if cost_data.get('page_nums') and cost_data.get('page_count'):
|
||||
page_nums = cost_data['page_nums'].split(',')
|
||||
required_set = set(range(1, cost_data['page_count'] + 1))
|
||||
page_set = set([int(num) for num in page_nums])
|
||||
cost_missing_page = required_set - page_set
|
||||
if cost_missing_page:
|
||||
cost_missing_page = sorted(cost_missing_page)
|
||||
cost_missing_page = [str(num) for num in cost_missing_page]
|
||||
page_description.append(f"《住院费用清单》,缺第{','.join(cost_missing_page)}页")
|
||||
|
||||
if page_description:
|
||||
review_result['full_page'] = False
|
||||
review_result['page_description'] = ';'.join(page_description)
|
||||
else:
|
||||
review_result['full_page'] = True
|
||||
|
||||
review_result['integrity'] = (review_result['has_settlement'] and review_result['has_discharge']
|
||||
and review_result['has_cost'] and review_result['full_page'])
|
||||
|
||||
# 三项资料一致性判断
|
||||
# 姓名一致性
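# Codes assigned below: '1' all three names match and equal the claimant's name;
# '5' all three match each other but not the claimant; '2' the settlement name is
# the odd one out; '3' the discharge-record name is the odd one out; '4' the
# cost-list name is the odd one out; '0' two or more names are missing, or all
# three differ.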
|
||||
name_list = [settlement_data['name'], discharge_data['name'], cost_data['name']]
|
||||
if sum(not bool(n) for n in name_list) > 1: # 有2个及以上空值直接认为都不一致
|
||||
review_result['name_match'] = '0'
|
||||
else:
|
||||
unique_name = set(name_list)
|
||||
if len(unique_name) == 1:
|
||||
review_result['name_match'] = '1' if name == unique_name.pop() else '5'
|
||||
elif len(unique_name) == 2:
|
||||
if settlement_data['name'] != discharge_data['name'] and settlement_data['name'] != cost_data['name']:
|
||||
review_result['name_match'] = '2'
|
||||
elif discharge_data['name'] != settlement_data['name'] and discharge_data['name'] != cost_data['name']:
|
||||
review_result['name_match'] = '3'
|
||||
else:
|
||||
review_result['name_match'] = '4'
|
||||
else:
|
||||
review_result['name_match'] = '0'
|
||||
|
||||
# 住院日期一致性
|
||||
if (settlement_data['admission_date'] and discharge_data['admission_date']
|
||||
and settlement_data['discharge_date'] and discharge_data['discharge_date']
|
||||
and settlement_data['admission_date'] == discharge_data['admission_date']
|
||||
and settlement_data['discharge_date'] == discharge_data['discharge_date']):
|
||||
review_result['admission_date_match'] = '1'
|
||||
else:
|
||||
review_result['admission_date_match'] = '0'
|
||||
|
||||
# 出院日期一致性
|
||||
discharge_date_list = [settlement_data['discharge_date'], discharge_data['discharge_date'],
|
||||
cost_data['discharge_date']]
|
||||
if sum(not bool(d) for d in discharge_date_list) > 1:
|
||||
review_result['discharge_date_match'] = '0'
|
||||
else:
|
||||
unique_discharge_date = set(discharge_date_list)
|
||||
if len(unique_discharge_date) == 1:
|
||||
review_result['discharge_date_match'] = '1'
|
||||
elif len(unique_discharge_date) == 2:
|
||||
if (settlement_data['discharge_date'] != discharge_data['discharge_date']
|
||||
and settlement_data['discharge_date'] != cost_data['discharge_date']):
|
||||
review_result['discharge_date_match'] = '2'
|
||||
elif (discharge_data['discharge_date'] != settlement_data['discharge_date']
|
||||
and discharge_data['discharge_date'] != cost_data['discharge_date']):
|
||||
review_result['discharge_date_match'] = '3'
|
||||
else:
|
||||
review_result['discharge_date_match'] = '4'
|
||||
else:
|
||||
review_result['discharge_date_match'] = '0'
|
||||
|
||||
review_result['consistency'] = (
|
||||
review_result['name_match'] == '1' and review_result['admission_date_match'] == '1'
|
||||
and review_result['discharge_date_match'] == '1')
|
||||
|
||||
review_result['success'] = review_result['integrity'] and review_result['consistency']
|
||||
save_or_update_ie(ZxIeReview, pk_phhd, review_result)
|
||||
|
||||
|
||||
def main():
|
||||
"""
|
||||
照片审核批量控制
|
||||
"""
|
||||
while 1:
|
||||
session = MysqlSession()
|
||||
phhds = (session.query(ZxPhhd.pk_phhd)
|
||||
phhds = (session.query(ZxPhhd.pk_phhd, ZxPhhd.cXm)
|
||||
.join(ZxPhrec, ZxPhhd.pk_phhd == ZxPhrec.pk_phhd, isouter=True)
|
||||
.filter(ZxPhhd.exsuccess_flag == "1")
|
||||
.filter(ZxPhrec.pk_phrec.isnot(None))
|
||||
@@ -562,14 +698,14 @@ def main():
|
||||
pk_phhd = phhd.pk_phhd
|
||||
logging.info(f"开始识别:{pk_phhd}")
|
||||
start_time = time.time()
|
||||
photo_review(pk_phhd)
|
||||
photo_review(pk_phhd, phhd.cXm)
|
||||
|
||||
# 识别完成更新标识
|
||||
session = MysqlSession()
|
||||
update_flag = (update(ZxPhhd).where(ZxPhhd.pk_phhd == pk_phhd).values(
|
||||
exsuccess_flag="8",
|
||||
ref_id1=HOSTNAME,
|
||||
checktime=util.get_default_datetime(),
|
||||
checktime=common_util.get_default_datetime(),
|
||||
fFSYLFY=time.time() - start_time))
|
||||
session.execute(update_flag)
|
||||
session.commit()
|
||||
|
||||
@@ -6,7 +6,7 @@ from sqlalchemy.sql.functions import count
|
||||
|
||||
from db import MysqlSession
|
||||
from db.mysql import ZxPhhd, ViewErrorReview
|
||||
from util import util
|
||||
from util import common_util
|
||||
|
||||
|
||||
def handle_reason(reason):
|
||||
@@ -95,5 +95,5 @@ if __name__ == '__main__':
|
||||
print(result)
|
||||
with open("photo_review_error_report.txt", 'w', encoding='utf-8') as file:
|
||||
file.write(json.dumps(result, indent=4, ensure_ascii=False))
|
||||
file.write(util.get_default_datetime())
|
||||
file.write(common_util.get_default_datetime())
|
||||
print("结果已保存。")
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
beautifulsoup4==4.12.3 # HTML parsing
jieba==0.42.1 # Chinese word segmentation
numpy==1.26.4
onnxconverter-common==1.14.0
OpenCC==1.1.6
OpenCC==1.1.9 # Traditional/Simplified Chinese conversion
opencv-python==4.6.0.66
paddle2onnx==1.2.3
paddleclas==2.5.2
paddlenlp==2.6.1
paddleocr==2.7.3
opencv-python-headless==4.10.0.84
pillow==10.4.0
PyMuPDF==1.24.9 # PDF handling
pymysql==1.1.1
rapidfuzz==3.9.4 # text similarity
requests==2.32.3
sqlacodegen==2.3.0.post1
sqlalchemy==1.4.52
tenacity==8.5.0
ufile==3.2.9
zxing-cpp==2.2.0
sqlacodegen==2.3.0.post1 # entity class generation
sqlalchemy==1.4.52 # ORM framework
tenacity==8.5.0 # retry
ufile==3.2.9 # cloud storage
zxing-cpp==2.2.0 # QR code recognition
245
services/paddle_services/.dockerignore
Normal file
@@ -0,0 +1,245 @@
|
||||
### PyCharm template
|
||||
# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
|
||||
# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
|
||||
|
||||
# User-specific stuff
|
||||
.idea/**/workspace.xml
|
||||
.idea/**/tasks.xml
|
||||
.idea/**/usage.statistics.xml
|
||||
.idea/**/dictionaries
|
||||
.idea/**/shelf
|
||||
|
||||
# AWS User-specific
|
||||
.idea/**/aws.xml
|
||||
|
||||
# Generated files
|
||||
.idea/**/contentModel.xml
|
||||
|
||||
# Sensitive or high-churn files
|
||||
.idea/**/dataSources/
|
||||
.idea/**/dataSources.ids
|
||||
.idea/**/dataSources.local.xml
|
||||
.idea/**/sqlDataSources.xml
|
||||
.idea/**/dynamic.xml
|
||||
.idea/**/uiDesigner.xml
|
||||
.idea/**/dbnavigator.xml
|
||||
|
||||
# Gradle
|
||||
.idea/**/gradle.xml
|
||||
.idea/**/libraries
|
||||
|
||||
# Gradle and Maven with auto-import
|
||||
# When using Gradle or Maven with auto-import, you should exclude module files,
|
||||
# since they will be recreated, and may cause churn. Uncomment if using
|
||||
# auto-import.
|
||||
# .idea/artifacts
|
||||
# .idea/compiler.xml
|
||||
# .idea/jarRepositories.xml
|
||||
# .idea/modules.xml
|
||||
# .idea/*.iml
|
||||
# .idea/modules
|
||||
# *.iml
|
||||
# *.ipr
|
||||
|
||||
# CMake
|
||||
cmake-build-*/
|
||||
|
||||
# Mongo Explorer plugin
|
||||
.idea/**/mongoSettings.xml
|
||||
|
||||
# File-based project format
|
||||
*.iws
|
||||
|
||||
# IntelliJ
|
||||
out/
|
||||
|
||||
# mpeltonen/sbt-idea plugin
|
||||
.idea_modules/
|
||||
|
||||
# JIRA plugin
|
||||
atlassian-ide-plugin.xml
|
||||
|
||||
# Cursive Clojure plugin
|
||||
.idea/replstate.xml
|
||||
|
||||
# SonarLint plugin
|
||||
.idea/sonarlint/
|
||||
|
||||
# Crashlytics plugin (for Android Studio and IntelliJ)
|
||||
com_crashlytics_export_strings.xml
|
||||
crashlytics.properties
|
||||
crashlytics-build.properties
|
||||
fabric.properties
|
||||
|
||||
# Editor-based Rest Client
|
||||
.idea/httpRequests
|
||||
|
||||
# Android studio 3.1+ serialized cache file
|
||||
.idea/caches/build_file_checksums.ser
|
||||
|
||||
### Python template
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# 通过卷绑定挂载到容器中
|
||||
/log
|
||||
/model
|
||||
# docker
|
||||
Dockerfile
|
||||
28
services/paddle_services/Dockerfile
Normal file
@@ -0,0 +1,28 @@
|
||||
# 使用官方的paddle镜像作为基础
|
||||
FROM registry.baidubce.com/paddlepaddle/paddle:2.6.1-gpu-cuda12.0-cudnn8.9-trt8.6
|
||||
|
||||
# 设置工作目录
|
||||
WORKDIR /app
|
||||
|
||||
# 设置环境变量
|
||||
ENV PYTHONUNBUFFERED=1 \
|
||||
# 设置时区
|
||||
TZ=Asia/Shanghai \
|
||||
# 设置pip镜像地址,加快安装速度
|
||||
PIP_INDEX_URL=https://pypi.tuna.tsinghua.edu.cn/simple
|
||||
|
||||
# 安装依赖
|
||||
COPY requirements.txt /app/requirements.txt
|
||||
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo '$TZ' > /etc/timezone \
|
||||
&& pip install --no-cache-dir -r requirements.txt \
|
||||
&& pip uninstall -y onnxruntime onnxruntime-gpu \
|
||||
&& pip install onnxruntime-gpu==1.18.0 --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
|
||||
|
||||
# 将当前目录内容复制到容器的/app内
|
||||
COPY . /app
|
||||
|
||||
# 暴露端口
|
||||
# EXPOSE 8081
|
||||
|
||||
# 运行api接口,具体接口在命令行或docker-compose.yml文件中定义
|
||||
ENTRYPOINT ["gunicorn"]
|
||||
21
services/paddle_services/__init__.py
Normal file
@@ -0,0 +1,21 @@
|
||||
"""
|
||||
信息抽取关键词配置
|
||||
"""
|
||||
IE_KEY = {
|
||||
'name': '患者姓名',
|
||||
'admission_date': '入院日期',
|
||||
'discharge_date': '出院日期',
|
||||
'medical_expenses': '费用总额',
|
||||
'personal_cash_payment': '个人现金支付',
|
||||
'personal_account_payment': '个人账户支付',
|
||||
'personal_funded_amount': '自费金额',
|
||||
'medical_insurance_type': '医保类型',
|
||||
'hospital': '医院',
|
||||
'department': '科室',
|
||||
'doctor': '主治医生',
|
||||
'admission_id': '住院号',
|
||||
'settlement_id': '医保结算单号码',
|
||||
'age': '年龄',
|
||||
'uppercase_medical_expenses': '大写总额',
|
||||
'page': '页码',
|
||||
}
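# Illustrative only (not part of the change set): the label values above are what
# the UIE services use as their extraction schema; a settlement-style schema could
# be assembled from a subset, e.g.
#
#   settlement_schema = [IE_KEY['name'], IE_KEY['admission_date'], IE_KEY['discharge_date'],
#                        IE_KEY['medical_expenses'], IE_KEY['settlement_id']]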
|
||||
30
services/paddle_services/clas_orientation.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import logging.config

from flask import Flask, request
from paddleclas import PaddleClas

from log import LOGGING_CONFIG
from utils import process_request

app = Flask(__name__)
CLAS = PaddleClas(model_name='text_image_orientation')


@app.route(rule='/', methods=['POST'])
@process_request
def main():
    """
    Determine the rotation angle of an image; rotating counter-clockwise by this angle makes it upright.
    Possible values: ['0', '90', '180', '270']
    :return: the two most likely angles
    """
    img_path = request.form.get('img_path')
    clas_result = CLAS.predict(input_data=img_path)
    clas_result = next(clas_result)[0]
    if clas_result['scores'][0] < 0.5:
        return ['0', '90']
    return clas_result['label_names']


if __name__ == '__main__':
    logging.config.dictConfig(LOGGING_CONFIG)
    app.run('0.0.0.0', 5005)
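# Hedged client sketch (not part of the change set): a caller such as
# model_util.clas_orientation could reach this service with a plain form POST;
# the exact response envelope depends on the process_request decorator.
#
#   import requests
#   resp = requests.post('http://127.0.0.1:5005/', data={'img_path': '/tmp_img/example.jpg'})  # hypothetical path
#   angles = resp.json()  # e.g. ['0', '90'] -- the two most likely angles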
|
||||
31
services/paddle_services/clas_text.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import logging.config

from flask import Flask, request
from paddlenlp import Taskflow

from log import LOGGING_CONFIG
from utils import process_request

app = Flask(__name__)
schema = ['基本医保结算单', '出院记录', '费用清单']
CLAS = Taskflow('zero_shot_text_classification', model='utc-xbase', schema=schema,
                task_path='model/text_classification', precision='fp32')


@app.route('/', methods=['POST'])
@process_request
def main():
    text = request.form.get('text')
    cls_result = CLAS(text)
    cls_result = cls_result[0].get('predictions')
    if cls_result:
        cls_result = cls_result[0]
        if cls_result['score'] and float(cls_result['score']) < 0.8:
            logging.info(f"识别结果置信度{cls_result['score']}过低!text: {text}")
            return None
        return cls_result['label']


if __name__ == '__main__':
    logging.config.dictConfig(LOGGING_CONFIG)
    app.run('0.0.0.0', 5008)
|
||||
31
services/paddle_services/det_book.py
Normal file
@@ -0,0 +1,31 @@
|
||||
import logging.config
|
||||
import os.path
|
||||
|
||||
import cv2
|
||||
from flask import Flask, request
|
||||
|
||||
from log import LOGGING_CONFIG
|
||||
from paddle_detection import detector
|
||||
from utils import process_request, parse_img_path
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
@app.route('/', methods=['POST'])
|
||||
@process_request
|
||||
def main():
|
||||
img_path = request.form.get('img_path')
|
||||
result = detector.get_book_areas(img_path)
|
||||
|
||||
dirname, img_name, img_ext = parse_img_path(img_path)
|
||||
books_path = []
|
||||
for i in range(len(result)):
|
||||
save_path = os.path.join(dirname, f'{img_name}.book_{i}.{img_ext}')
|
||||
cv2.imwrite(save_path, result[i])
|
||||
books_path.append(save_path)
|
||||
return books_path
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.config.dictConfig(LOGGING_CONFIG)
|
||||
app.run('0.0.0.0', 5006)
|
||||
28
services/paddle_services/dewarp.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import logging.config
|
||||
import os
|
||||
|
||||
import cv2
|
||||
from flask import Flask, request
|
||||
|
||||
from doc_dewarp import dewarper
|
||||
from log import LOGGING_CONFIG
|
||||
from utils import process_request, parse_img_path
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
|
||||
@app.route('/', methods=['POST'])
|
||||
@process_request
|
||||
def main():
|
||||
img_path = request.form.get('img_path')
|
||||
img = cv2.imread(img_path)
|
||||
dewarped_img = dewarper.dewarp_image(img)
|
||||
dirname, img_name, img_ext = parse_img_path(img_path)
|
||||
save_path = os.path.join(dirname, f'{img_name}.dewarped.{img_ext}')
|
||||
cv2.imwrite(save_path, dewarped_img)
|
||||
return save_path
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.config.dictConfig(LOGGING_CONFIG)
|
||||
app.run('0.0.0.0', 5007)
|
||||
167
services/paddle_services/doc_dewarp/.gitignore
vendored
Normal file
@@ -0,0 +1,167 @@
|
||||
input/
|
||||
output/
|
||||
|
||||
runs/
|
||||
*_dump/
|
||||
*_log/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
34
services/paddle_services/doc_dewarp/.pre-commit-config.yaml
Normal file
@@ -0,0 +1,34 @@
|
||||
repos:
|
||||
# Common hooks
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
- id: check-merge-conflict
|
||||
- id: check-symlinks
|
||||
- id: detect-private-key
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/Lucas-C/pre-commit-hooks.git
|
||||
rev: v1.5.1
|
||||
hooks:
|
||||
- id: remove-crlf
|
||||
- id: remove-tabs
|
||||
name: Tabs remover (Python)
|
||||
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
|
||||
args: [--whitespaces-count, '4']
|
||||
# For Python files
|
||||
- repo: https://github.com/psf/black.git
|
||||
rev: 23.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.11.5
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.0.272
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix, --no-cache]
|
||||
398
services/paddle_services/doc_dewarp/GeoTr.py
Normal file
@@ -0,0 +1,398 @@
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from .extractor import BasicEncoder
|
||||
from .position_encoding import build_position_encoding
|
||||
from .weight_init import weight_init_
|
||||
|
||||
|
||||
class attnLayer(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
nhead=8,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.self_attn = nn.MultiHeadAttention(d_model, nhead, dropout=dropout)
|
||||
self.multihead_attn_list = nn.LayerList(
|
||||
[
|
||||
copy.deepcopy(nn.MultiHeadAttention(d_model, nhead, dropout=dropout))
|
||||
for i in range(2)
|
||||
]
|
||||
)
|
||||
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2_list = nn.LayerList(
|
||||
[copy.deepcopy(nn.LayerNorm(d_model)) for i in range(2)]
|
||||
)
|
||||
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(p=dropout)
|
||||
self.dropout2_list = nn.LayerList(
|
||||
[copy.deepcopy(nn.Dropout(p=dropout)) for i in range(2)]
|
||||
)
|
||||
self.dropout3 = nn.Dropout(p=dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[paddle.Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
tgt,
|
||||
memory_list,
|
||||
tgt_mask=None,
|
||||
memory_mask=None,
|
||||
pos=None,
|
||||
memory_pos=None,
|
||||
):
|
||||
q = k = self.with_pos_embed(tgt, pos)
|
||||
tgt2 = self.self_attn(
|
||||
q.transpose((1, 0, 2)),
|
||||
k.transpose((1, 0, 2)),
|
||||
value=tgt.transpose((1, 0, 2)),
|
||||
attn_mask=tgt_mask,
|
||||
)
|
||||
tgt2 = tgt2.transpose((1, 0, 2))
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
|
||||
for memory, multihead_attn, norm2, dropout2, m_pos in zip(
|
||||
memory_list,
|
||||
self.multihead_attn_list,
|
||||
self.norm2_list,
|
||||
self.dropout2_list,
|
||||
memory_pos,
|
||||
):
|
||||
tgt2 = multihead_attn(
|
||||
query=self.with_pos_embed(tgt, pos).transpose((1, 0, 2)),
|
||||
key=self.with_pos_embed(memory, m_pos).transpose((1, 0, 2)),
|
||||
value=memory.transpose((1, 0, 2)),
|
||||
attn_mask=memory_mask,
|
||||
).transpose((1, 0, 2))
|
||||
|
||||
tgt = tgt + dropout2(tgt2)
|
||||
tgt = norm2(tgt)
|
||||
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
|
||||
return tgt
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask=None,
|
||||
memory_mask=None,
|
||||
pos=None,
|
||||
memory_pos=None,
|
||||
):
|
||||
tgt2 = self.norm1(tgt)
|
||||
|
||||
q = k = self.with_pos_embed(tgt2, pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask)
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2 = self.norm2(tgt)
|
||||
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt2, pos),
|
||||
key=self.with_pos_embed(memory, memory_pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
)
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
|
||||
return tgt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory_list,
|
||||
tgt_mask=None,
|
||||
memory_mask=None,
|
||||
pos=None,
|
||||
memory_pos=None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(
|
||||
tgt,
|
||||
memory_list,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
pos,
|
||||
memory_pos,
|
||||
)
|
||||
return self.forward_post(
|
||||
tgt,
|
||||
memory_list,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
pos,
|
||||
memory_pos,
|
||||
)
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
return nn.LayerList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
||||
|
||||
|
||||
class TransDecoder(nn.Layer):
|
||||
def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
|
||||
super(TransDecoder, self).__init__()
|
||||
|
||||
attn_layer = attnLayer(hidden_dim)
|
||||
self.layers = _get_clones(attn_layer, num_attn_layers)
|
||||
self.position_embedding = build_position_encoding(hidden_dim)
|
||||
|
||||
def forward(self, image: paddle.Tensor, query_embed: paddle.Tensor):
|
||||
pos = self.position_embedding(
|
||||
paddle.ones([image.shape[0], image.shape[2], image.shape[3]], dtype="bool")
|
||||
)
|
||||
|
||||
b, c, h, w = image.shape
|
||||
|
||||
image = image.flatten(2).transpose(perm=[2, 0, 1])
|
||||
pos = pos.flatten(2).transpose(perm=[2, 0, 1])
|
||||
|
||||
for layer in self.layers:
|
||||
query_embed = layer(query_embed, [image], pos=pos, memory_pos=[pos, pos])
|
||||
|
||||
query_embed = query_embed.transpose(perm=[1, 2, 0]).reshape([b, c, h, w])
|
||||
|
||||
return query_embed
|
||||
|
||||
|
||||
class TransEncoder(nn.Layer):
|
||||
def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
|
||||
super(TransEncoder, self).__init__()
|
||||
|
||||
attn_layer = attnLayer(hidden_dim)
|
||||
self.layers = _get_clones(attn_layer, num_attn_layers)
|
||||
self.position_embedding = build_position_encoding(hidden_dim)
|
||||
|
||||
def forward(self, image: paddle.Tensor):
|
||||
pos = self.position_embedding(
|
||||
paddle.ones([image.shape[0], image.shape[2], image.shape[3]], dtype="bool")
|
||||
)
|
||||
|
||||
b, c, h, w = image.shape
|
||||
|
||||
image = image.flatten(2).transpose(perm=[2, 0, 1])
|
||||
pos = pos.flatten(2).transpose(perm=[2, 0, 1])
|
||||
|
||||
for layer in self.layers:
|
||||
image = layer(image, [image], pos=pos, memory_pos=[pos, pos])
|
||||
|
||||
image = image.transpose(perm=[1, 2, 0]).reshape([b, c, h, w])
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class FlowHead(nn.Layer):
|
||||
def __init__(self, input_dim=128, hidden_dim=256):
|
||||
super(FlowHead, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2D(input_dim, hidden_dim, 3, padding=1)
|
||||
self.conv2 = nn.Conv2D(hidden_dim, 2, 3, padding=1)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv2(self.relu(self.conv1(x)))
|
||||
|
||||
|
||||
class UpdateBlock(nn.Layer):
|
||||
def __init__(self, hidden_dim: int = 128):
|
||||
super(UpdateBlock, self).__init__()
|
||||
|
||||
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
|
||||
|
||||
self.mask = nn.Sequential(
|
||||
nn.Conv2D(hidden_dim, 256, 3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2D(256, 64 * 9, 1, padding=0),
|
||||
)
|
||||
|
||||
def forward(self, image, coords):
|
||||
mask = 0.25 * self.mask(image)
|
||||
dflow = self.flow_head(image)
|
||||
coords = coords + dflow
|
||||
return mask, coords
|
||||
|
||||
|
||||
def coords_grid(batch, ht, wd):
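# Build a (batch, 2, ht, wd) grid of (x, y) pixel coordinates.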
|
||||
coords = paddle.meshgrid(paddle.arange(end=ht), paddle.arange(end=wd))
|
||||
coords = paddle.stack(coords[::-1], axis=0).astype(dtype="float32")
|
||||
return coords[None].tile([batch, 1, 1, 1])
|
||||
|
||||
|
||||
def upflow8(flow, mode="bilinear"):
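# Upsample a flow field 8x with bilinear interpolation and scale its values by 8.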
|
||||
new_size = 8 * flow.shape[2], 8 * flow.shape[3]
|
||||
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
|
||||
|
||||
|
||||
class OverlapPatchEmbed(nn.Layer):
|
||||
"""Image to Patch Embedding"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
|
||||
super().__init__()
|
||||
|
||||
img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
|
||||
patch_size = (
|
||||
patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
|
||||
)
|
||||
|
||||
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
||||
self.num_patches = self.H * self.W
|
||||
|
||||
self.proj = nn.Conv2D(
|
||||
in_chans,
|
||||
embed_dim,
|
||||
patch_size,
|
||||
stride=stride,
|
||||
padding=(patch_size[0] // 2, patch_size[1] // 2),
|
||||
)
|
||||
self.norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
weight_init_(m, "trunc_normal_", std=0.02)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
weight_init_(m, "Constant", value=1.0)
|
||||
elif isinstance(m, nn.Conv2D):
|
||||
weight_init_(
|
||||
m.weight, "kaiming_normal_", mode="fan_out", nonlinearity="relu"
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
_, _, H, W = x.shape
|
||||
x = x.flatten(2)
|
||||
|
||||
perm = list(range(x.ndim))
|
||||
perm[1] = 2
|
||||
perm[2] = 1
|
||||
x = x.transpose(perm=perm)
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
return x, H, W
|
||||
|
||||
|
||||
class GeoTr(nn.Layer):
|
||||
def __init__(self):
|
||||
super(GeoTr, self).__init__()
|
||||
|
||||
self.hidden_dim = hdim = 256
|
||||
|
||||
self.fnet = BasicEncoder(output_dim=hdim, norm_fn="instance")
|
||||
|
||||
self.encoder_block = [("encoder_block" + str(i)) for i in range(3)]
|
||||
for i in self.encoder_block:
|
||||
self.__setattr__(i, TransEncoder(2, hidden_dim=hdim))
|
||||
|
||||
self.down_layer = [("down_layer" + str(i)) for i in range(2)]
|
||||
for i in self.down_layer:
|
||||
self.__setattr__(i, nn.Conv2D(256, 256, 3, stride=2, padding=1))
|
||||
|
||||
self.decoder_block = [("decoder_block" + str(i)) for i in range(3)]
|
||||
for i in self.decoder_block:
|
||||
self.__setattr__(i, TransDecoder(2, hidden_dim=hdim))
|
||||
|
||||
self.up_layer = [("up_layer" + str(i)) for i in range(2)]
|
||||
for i in self.up_layer:
|
||||
self.__setattr__(
|
||||
i, nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
|
||||
)
|
||||
|
||||
self.query_embed = nn.Embedding(81, self.hidden_dim)
|
||||
|
||||
self.update_block = UpdateBlock(self.hidden_dim)
|
||||
|
||||
def initialize_flow(self, img):
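# Return a full-resolution coordinate grid plus two 1/8-resolution grids; the flow is later computed as coords1 - coords0.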
|
||||
N, _, H, W = img.shape
|
||||
coodslar = coords_grid(N, H, W)
|
||||
coords0 = coords_grid(N, H // 8, W // 8)
|
||||
coords1 = coords_grid(N, H // 8, W // 8)
|
||||
return coodslar, coords0, coords1
|
||||
|
||||
def upsample_flow(self, flow, mask):
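# Convex upsampling (as in RAFT): each fine-grid flow value is a mask-weighted combination of its 3x3 coarse neighbors, upsampled 8x.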
|
||||
N, _, H, W = flow.shape
|
||||
|
||||
mask = mask.reshape([N, 1, 9, 8, 8, H, W])
|
||||
mask = F.softmax(mask, axis=2)
|
||||
|
||||
up_flow = F.unfold(8 * flow, [3, 3], paddings=1)
|
||||
up_flow = up_flow.reshape([N, 2, 9, 1, 1, H, W])
|
||||
|
||||
up_flow = paddle.sum(mask * up_flow, axis=2)
|
||||
up_flow = up_flow.transpose(perm=[0, 1, 4, 2, 5, 3])
|
||||
|
||||
return up_flow.reshape([N, 2, 8 * H, 8 * W])
|
||||
|
||||
def forward(self, image):
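# CNN features -> transformer encoder/decoder pyramid -> one update step that turns the predicted flow into the backward map bm_up at input resolution.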
|
||||
fmap = self.fnet(image)
|
||||
fmap = F.relu(fmap)
|
||||
|
||||
fmap1 = self.__getattr__(self.encoder_block[0])(fmap)
|
||||
fmap1d = self.__getattr__(self.down_layer[0])(fmap1)
|
||||
fmap2 = self.__getattr__(self.encoder_block[1])(fmap1d)
|
||||
fmap2d = self.__getattr__(self.down_layer[1])(fmap2)
|
||||
fmap3 = self.__getattr__(self.encoder_block[2])(fmap2d)
|
||||
|
||||
query_embed0 = self.query_embed.weight.unsqueeze(1).tile([1, fmap3.shape[0], 1])
|
||||
|
||||
fmap3d_ = self.__getattr__(self.decoder_block[0])(fmap3, query_embed0)
|
||||
fmap3du_ = (
|
||||
self.__getattr__(self.up_layer[0])(fmap3d_)
|
||||
.flatten(2)
|
||||
.transpose(perm=[2, 0, 1])
|
||||
)
|
||||
fmap2d_ = self.__getattr__(self.decoder_block[1])(fmap2, fmap3du_)
|
||||
fmap2du_ = (
|
||||
self.__getattr__(self.up_layer[1])(fmap2d_)
|
||||
.flatten(2)
|
||||
.transpose(perm=[2, 0, 1])
|
||||
)
|
||||
fmap_out = self.__getattr__(self.decoder_block[2])(fmap1, fmap2du_)
|
||||
|
||||
coodslar, coords0, coords1 = self.initialize_flow(image)
|
||||
coords1 = coords1.detach()
|
||||
mask, coords1 = self.update_block(fmap_out, coords1)
|
||||
flow_up = self.upsample_flow(coords1 - coords0, mask)
|
||||
bm_up = coodslar + flow_up
|
||||
|
||||
return bm_up
|
||||
73
services/paddle_services/doc_dewarp/README.md
Normal file
73
services/paddle_services/doc_dewarp/README.md
Normal file
@@ -0,0 +1,73 @@
# DocTrPP: DocTr++ in PaddlePaddle

## Introduction

This is a PaddlePaddle implementation of DocTr++. The original paper is [DocTr++: Deep Unrestricted Document Image Rectification](https://arxiv.org/abs/2304.08796), and the original code is available [here](https://github.com/fh2019ustc/DocTr-Plus).

![doctr_plus](https://github.com/GreatV/DocTrPP/assets/17264618/42543187-e1fd-47a1-b75c-9a1735c6f7b3/image1.png)

## Requirements

Install the latest release of PaddlePaddle by following the official installation guide at this [link](https://www.paddlepaddle.org.cn/).

## Training

1. Data Preparation

To prepare the datasets, refer to [doc3D](https://github.com/cvlab-stonybrook/doc3D-dataset). A helper for generating the train/val split lists is sketched right after this list.

2. Training

```shell
sh train.sh
```

or

```shell
export OPENCV_IO_ENABLE_OPENEXR=1
export CUDA_VISIBLE_DEVICES=0

python train.py --img-size 288 \
    --name "DocTr++" \
    --batch-size 12 \
    --lr 2.5e-5 \
    --exist-ok \
    --use-vdl
```

3. Load Trained Model and Continue Training

```shell
export OPENCV_IO_ENABLE_OPENEXR=1
export CUDA_VISIBLE_DEVICES=0

python train.py --img-size 288 \
    --name "DocTr++" \
    --batch-size 12 \
    --lr 2.5e-5 \
    --resume "runs/train/DocTr++/weights/last.ckpt" \
    --exist-ok \
    --use-vdl
```
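The train/validation split lists are generated by `split_dataset.py` in this directory; a typical invocation (the arguments shown match the script's own defaults) might look like:

```shell
python split_dataset.py --data_root ~/datasets/doc3d --train_ratio 0.8
```

This writes `train.txt` and `val.txt` under the dataset root, which `Doc3dDataset` reads at training time.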
## Test and Inference

Test the dewarp result on a single image:

```shell
python predict.py -i "crop/12_2 copy.png" -m runs/train/DocTr++/weights/best.ckpt -o 12.2.png
```

![dewarp_result](https://github.com/GreatV/DocTrPP/assets/17264618/42543187-e1fd-47a1-b75c-9a1735c6f7b3/image2.png)

## Export to ONNX

```shell
pip install paddle2onnx

python export.py -m ./best.ckpt --format onnx
```
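The exported model can then be run with onnxruntime; a minimal sketch (hypothetical file names; the preprocessing mirrors `dewarper.py`, and the input name `image` matches the deployed model in `doc_dewarp/__init__.py` — the name produced by `export.py` may differ):

```python
import cv2
import numpy as np
from onnxruntime import InferenceSession

sess = InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

img = cv2.imread("warped.jpg")                       # hypothetical input image
img = cv2.resize(img, (288, 288)).astype(np.float32)
img = np.transpose(img, (2, 0, 1))[None]             # HWC -> NCHW with batch dim
bm = sess.run(None, {"image": img})[0]               # predicted backward map, shape (1, 2, 288, 288)
```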
## Model Download

The trained model can be downloaded from [here](https://github.com/GreatV/DocTrPP/releases/download/v0.0.2/best.ckpt).
7
services/paddle_services/doc_dewarp/__init__.py
Normal file
7
services/paddle_services/doc_dewarp/__init__.py
Normal file
@@ -0,0 +1,7 @@
import os

from onnxruntime import InferenceSession

MODEL_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
                          'model', 'dewarp_model', 'doc_tr_pp.onnx')
DOC_TR = InferenceSession(MODEL_PATH, providers=['CUDAExecutionProvider'], provider_options=[{'device_id': 0}])
133
services/paddle_services/doc_dewarp/data_visualization.py
Normal file
133
services/paddle_services/doc_dewarp/data_visualization.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import hdf5storage as h5
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
|
||||
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--data_root",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="~/datasets/doc3d/",
|
||||
help="Path to the downloaded dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--folder", nargs="?", type=int, default=1, help="Folder ID to read from"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="output.png",
|
||||
help="Output filename for the image",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
root = os.path.expanduser(args.data_root)
|
||||
folder = args.folder
|
||||
dirname = os.path.join(root, "img", str(folder))
|
||||
|
||||
choices = [f for f in os.listdir(dirname) if "png" in f]
|
||||
fname = random.choice(choices)
|
||||
|
||||
# Read Image
|
||||
img_path = os.path.join(dirname, fname)
|
||||
img = cv2.imread(img_path)
|
||||
|
||||
# Read 3D Coords
|
||||
wc_path = os.path.join(root, "wc", str(folder), fname[:-3] + "exr")
|
||||
wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
|
||||
# scale wc
|
||||
# value obtained from the entire dataset
|
||||
xmx, xmn, ymx, ymn, zmx, zmn = (
|
||||
1.2539363,
|
||||
-1.2442188,
|
||||
1.2396319,
|
||||
-1.2289206,
|
||||
0.6436657,
|
||||
-0.67492497,
|
||||
)
|
||||
wc[:, :, 0] = (wc[:, :, 0] - zmn) / (zmx - zmn)
|
||||
wc[:, :, 1] = (wc[:, :, 1] - ymn) / (ymx - ymn)
|
||||
wc[:, :, 2] = (wc[:, :, 2] - xmn) / (xmx - xmn)
|
||||
|
||||
# Read Backward Map
|
||||
bm_path = os.path.join(root, "bm", str(folder), fname[:-3] + "mat")
|
||||
bm = h5.loadmat(bm_path)["bm"]
|
||||
|
||||
# Read UV Map
|
||||
uv_path = os.path.join(root, "uv", str(folder), fname[:-3] + "exr")
|
||||
uv = cv2.imread(uv_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
|
||||
|
||||
# Read Depth Map
|
||||
dmap_path = os.path.join(root, "dmap", str(folder), fname[:-3] + "exr")
|
||||
dmap = cv2.imread(dmap_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)[:, :, 0]
|
||||
# do some clipping and scaling to display it
|
||||
dmap[dmap > 30.0] = 30
|
||||
dmap = 1 - ((dmap - np.min(dmap)) / (np.max(dmap) - np.min(dmap)))
|
||||
|
||||
# Read Normal Map
|
||||
norm_path = os.path.join(root, "norm", str(folder), fname[:-3] + "exr")
|
||||
norm = cv2.imread(norm_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
|
||||
|
||||
# Read Albedo
|
||||
alb_path = os.path.join(root, "alb", str(folder), fname[:-3] + "png")
|
||||
alb = cv2.imread(alb_path)
|
||||
|
||||
# Read Checkerboard Image
|
||||
recon_path = os.path.join(root, "recon", str(folder), fname[:-8] + "chess480001.png")
|
||||
recon = cv2.imread(recon_path)
|
||||
|
||||
# Display image and GTs
|
||||
|
||||
# use the backward mapping to dewarp the image
|
||||
# scale bm to -1.0 to 1.0
|
||||
bm_ = bm / np.array([448, 448])
|
||||
bm_ = (bm_ - 0.5) * 2
|
||||
bm_ = np.reshape(bm_, (1, 448, 448, 2))
|
||||
bm_ = paddle.to_tensor(bm_, dtype="float32")
|
||||
img_ = alb.transpose((2, 0, 1)).astype(np.float32) / 255.0
|
||||
img_ = np.expand_dims(img_, 0)
|
||||
img_ = paddle.to_tensor(img_, dtype="float32")
|
||||
uw = F.grid_sample(img_, bm_)
|
||||
uw = uw[0].numpy().transpose((1, 2, 0))
|
||||
|
||||
f, axrr = plt.subplots(2, 5)
|
||||
for ax in axrr:
|
||||
for a in ax:
|
||||
a.set_xticks([])
|
||||
a.set_yticks([])
|
||||
|
||||
axrr[0][0].imshow(img)
|
||||
axrr[0][0].title.set_text("image")
|
||||
axrr[0][1].imshow(wc)
|
||||
axrr[0][1].title.set_text("3D coords")
|
||||
axrr[0][2].imshow(bm[:, :, 0])
|
||||
axrr[0][2].title.set_text("bm 0")
|
||||
axrr[0][3].imshow(bm[:, :, 1])
|
||||
axrr[0][3].title.set_text("bm 1")
|
||||
if uv is None:
|
||||
uv = np.zeros_like(img)
|
||||
axrr[0][4].imshow(uv)
|
||||
axrr[0][4].title.set_text("uv map")
|
||||
axrr[1][0].imshow(dmap)
|
||||
axrr[1][0].title.set_text("depth map")
|
||||
axrr[1][1].imshow(norm)
|
||||
axrr[1][1].title.set_text("normal map")
|
||||
axrr[1][2].imshow(alb)
|
||||
axrr[1][2].title.set_text("albedo")
|
||||
axrr[1][3].imshow(recon)
|
||||
axrr[1][3].title.set_text("checkerboard")
|
||||
axrr[1][4].imshow(uw)
|
||||
axrr[1][4].title.set_text("gt unwarped")
|
||||
plt.tight_layout()
|
||||
plt.savefig(args.output)
|
||||
21
services/paddle_services/doc_dewarp/dewarper.py
Normal file
21
services/paddle_services/doc_dewarp/dewarper.py
Normal file
@@ -0,0 +1,21 @@
import cv2
import numpy as np
import paddle

from . import DOC_TR
from .utils import to_tensor, to_image


def dewarp_image(image):
    # Resize to the 288 x 288 input expected by the exported DocTr++ model.
    img = cv2.resize(image, (288, 288)).astype(np.float32)
    # Keep a full-resolution tensor of the original image for the final sampling step.
    y = to_tensor(image)

    # HWC -> CHW, add a batch dimension and run the ONNX session.
    img = np.transpose(img, (2, 0, 1))
    bm = DOC_TR.run(None, {'image': img[None,]})[0]
    bm = paddle.to_tensor(bm)
    # Upsample the predicted backward map to the original image size.
    bm = paddle.nn.functional.interpolate(
        bm, y.shape[2:], mode='bilinear', align_corners=False
    )
    # NCHW -> NHWC for grid_sample, then rescale the map to [-1, 1].
    bm_nhwc = bm.transpose([0, 2, 3, 1])
    out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
    return to_image(out)
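A minimal usage sketch for `dewarp_image` (the import path and file names are hypothetical; it assumes the `doc_dewarp` package is importable and the ONNX model configured in `__init__.py` is present):

```python
import cv2

from doc_dewarp.dewarper import dewarp_image  # hypothetical import path

img = cv2.imread("warped.jpg")   # hypothetical photo of a curved document page
flat = dewarp_image(img)         # rectified image at the original resolution
cv2.imwrite("dewarped.png", flat)
```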
161
services/paddle_services/doc_dewarp/doc/download_dataset.sh
Normal file
161
services/paddle_services/doc_dewarp/doc/download_dataset.sh
Normal file
@@ -0,0 +1,161 @@
# Download the doc3d dataset archives hosted on the PaddleSeg BOS bucket.
BASE_URL="https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d"

for i in $(seq 1 21); do wget "${BASE_URL}/img_${i}.zip"; done

for i in $(seq 1 21); do wget "${BASE_URL}/wc_${i}.zip"; done

for i in $(seq 1 21); do wget "${BASE_URL}/bm_${i}.zip"; done

# Note: uv_9.zip is not part of the original download list.
for i in 1 2 3 4 5 6 7 8 10 11 12 13 14 15 16 17 18 19 20 21; do wget "${BASE_URL}/uv_${i}.zip"; done

for i in $(seq 1 8); do wget "${BASE_URL}/alb_${i}.zip"; done

for i in $(seq 1 21); do wget "${BASE_URL}/recon_${i}.zip"; done

for i in $(seq 1 21); do wget "${BASE_URL}/norm_${i}.zip"; done

for i in $(seq 1 21); do wget "${BASE_URL}/dmap_${i}.zip"; done
Binary file not shown.
129
services/paddle_services/doc_dewarp/doc3d_dataset.py
Normal file
129
services/paddle_services/doc_dewarp/doc3d_dataset.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import collections
|
||||
import os
|
||||
import random
|
||||
|
||||
import albumentations as A
|
||||
import cv2
|
||||
import hdf5storage as h5
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import io
|
||||
|
||||
# Set random seed
|
||||
random.seed(12345678)
|
||||
|
||||
|
||||
class Doc3dDataset(io.Dataset):
|
||||
def __init__(self, root, split="train", is_augment=False, image_size=512):
|
||||
self.root = os.path.expanduser(root)
|
||||
|
||||
self.split = split
|
||||
self.is_augment = is_augment
|
||||
|
||||
self.files = collections.defaultdict(list)
|
||||
|
||||
self.image_size = (
|
||||
image_size if isinstance(image_size, tuple) else (image_size, image_size)
|
||||
)
|
||||
|
||||
# Augmentation
|
||||
self.augmentation = A.Compose(
|
||||
[
|
||||
A.ColorJitter(),
|
||||
]
|
||||
)
|
||||
|
||||
for split in ["train", "val"]:
|
||||
path = os.path.join(self.root, split + ".txt")
|
||||
file_list = []
|
||||
with open(path, "r") as file:
|
||||
file_list = [file_id.rstrip() for file_id in file.readlines()]
|
||||
self.files[split] = file_list
|
||||
|
||||
def __len__(self):
|
||||
return len(self.files[self.split])
|
||||
|
||||
def __getitem__(self, index):
|
||||
image_name = self.files[self.split][index]
|
||||
|
||||
# Read image
|
||||
image_path = os.path.join(self.root, "img", image_name + ".png")
|
||||
image = cv2.imread(image_path)
|
||||
|
||||
# Read 3D Coordinates
|
||||
wc_path = os.path.join(self.root, "wc", image_name + ".exr")
|
||||
wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
|
||||
|
||||
# Read backward map
|
||||
bm_path = os.path.join(self.root, "bm", image_name + ".mat")
|
||||
bm = h5.loadmat(bm_path)["bm"]
|
||||
|
||||
image, bm = self.transform(wc, bm, image)
|
||||
|
||||
return image, bm
|
||||
|
||||
def tight_crop(self, wc: np.ndarray):
|
||||
mask = ((wc[:, :, 0] != 0) & (wc[:, :, 1] != 0) & (wc[:, :, 2] != 0)).astype(
|
||||
np.uint8
|
||||
)
|
||||
mask_size = mask.shape
|
||||
[y, x] = mask.nonzero()
|
||||
min_x = min(x)
|
||||
max_x = max(x)
|
||||
min_y = min(y)
|
||||
max_y = max(y)
|
||||
|
||||
wc = wc[min_y : max_y + 1, min_x : max_x + 1, :]
|
||||
s = 10
|
||||
wc = np.pad(wc, ((s, s), (s, s), (0, 0)), "constant")
|
||||
|
||||
cx1 = random.randint(0, 2 * s)
|
||||
cx2 = random.randint(0, 2 * s) + 1
|
||||
cy1 = random.randint(0, 2 * s)
|
||||
cy2 = random.randint(0, 2 * s) + 1
|
||||
|
||||
wc = wc[cy1:-cy2, cx1:-cx2, :]
|
||||
|
||||
top: int = min_y - s + cy1
|
||||
bottom: int = mask_size[0] - max_y - s + cy2
|
||||
left: int = min_x - s + cx1
|
||||
right: int = mask_size[1] - max_x - s + cx2
|
||||
|
||||
top = np.clip(top, 0, mask_size[0])
|
||||
bottom = np.clip(bottom, 1, mask_size[0] - 1)
|
||||
left = np.clip(left, 0, mask_size[1])
|
||||
right = np.clip(right, 1, mask_size[1] - 1)
|
||||
|
||||
return wc, top, bottom, left, right
|
||||
|
||||
def transform(self, wc, bm, img):
|
||||
wc, top, bottom, left, right = self.tight_crop(wc)
|
||||
|
||||
img = img[top:-bottom, left:-right, :]
|
||||
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
if self.is_augment:
|
||||
img = self.augmentation(image=img)["image"]
|
||||
|
||||
# resize image
|
||||
img = cv2.resize(img, self.image_size)
|
||||
img = img.astype(np.float32) / 255.0
|
||||
img = img.transpose(2, 0, 1)
|
||||
|
||||
# resize bm
|
||||
bm = bm.astype(np.float32)
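# Shift the backward map into the cropped frame and normalize by the crop size (the raw doc3d maps are defined on the original 448 x 448 images).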
|
||||
bm[:, :, 1] = bm[:, :, 1] - top
|
||||
bm[:, :, 0] = bm[:, :, 0] - left
|
||||
bm = bm / np.array([448.0 - left - right, 448.0 - top - bottom])
|
||||
bm0 = cv2.resize(bm[:, :, 0], (self.image_size[0], self.image_size[1]))
|
||||
bm1 = cv2.resize(bm[:, :, 1], (self.image_size[0], self.image_size[1]))
|
||||
bm0 = bm0 * self.image_size[0]
|
||||
bm1 = bm1 * self.image_size[1]
|
||||
|
||||
bm = np.stack([bm0, bm1], axis=-1)
|
||||
|
||||
img = paddle.to_tensor(img).astype(dtype="float32")
|
||||
bm = paddle.to_tensor(bm).astype(dtype="float32")
|
||||
|
||||
return img, bm
|
||||
66
services/paddle_services/doc_dewarp/export.py
Normal file
66
services/paddle_services/doc_dewarp/export.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import paddle
|
||||
|
||||
from GeoTr import GeoTr
|
||||
|
||||
|
||||
def export(args):
|
||||
model_path = args.model
|
||||
imgsz = args.imgsz
|
||||
format = args.format
|
||||
|
||||
model = GeoTr()
|
||||
checkpoint = paddle.load(model_path)
|
||||
model.set_state_dict(checkpoint["model"])
|
||||
model.eval()
|
||||
|
||||
dirname = os.path.dirname(model_path)
|
||||
if format == "static" or format == "onnx":
|
||||
model = paddle.jit.to_static(
|
||||
model,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec(shape=[1, 3, imgsz, imgsz], dtype="float32")
|
||||
],
|
||||
full_graph=True,
|
||||
)
|
||||
paddle.jit.save(model, os.path.join(dirname, "model"))
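# paddle.jit.save writes model.pdmodel / model.pdiparams in dirname; the paddle2onnx call below converts them to ONNX.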
|
||||
|
||||
if format == "onnx":
|
||||
onnx_path = os.path.join(dirname, "model.onnx")
|
||||
os.system(
|
||||
f"paddle2onnx --model_dir {dirname}"
|
||||
" --model_filename model.pdmodel"
|
||||
" --params_filename model.pdiparams"
|
||||
f" --save_file {onnx_path}"
|
||||
" --opset_version 11"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="export model")
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"-m",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="",
|
||||
help="The path of model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--imgsz", type=int, default=288, help="The size of input image"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--format",
|
||||
type=str,
|
||||
default="static",
|
||||
help="The format of exported model, which can be static or onnx",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
export(args)
|
||||
110
services/paddle_services/doc_dewarp/extractor.py
Normal file
110
services/paddle_services/doc_dewarp/extractor.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import paddle.nn as nn
|
||||
|
||||
from .weight_init import weight_init_
|
||||
|
||||
|
||||
class ResidualBlock(nn.Layer):
|
||||
"""Residual Block with custom normalization."""
|
||||
|
||||
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2D(in_planes, planes, 3, padding=1, stride=stride)
|
||||
self.conv2 = nn.Conv2D(planes, planes, 3, padding=1)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
if norm_fn == "group":
|
||||
num_groups = planes // 8
|
||||
self.norm1 = nn.GroupNorm(num_groups, planes)
|
||||
self.norm2 = nn.GroupNorm(num_groups, planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.GroupNorm(num_groups, planes)
|
||||
elif norm_fn == "batch":
|
||||
self.norm1 = nn.BatchNorm2D(planes)
|
||||
self.norm2 = nn.BatchNorm2D(planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.BatchNorm2D(planes)
|
||||
elif norm_fn == "instance":
|
||||
self.norm1 = nn.InstanceNorm2D(planes)
|
||||
self.norm2 = nn.InstanceNorm2D(planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.InstanceNorm2D(planes)
|
||||
elif norm_fn == "none":
|
||||
self.norm1 = nn.Sequential()
|
||||
self.norm2 = nn.Sequential()
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.Sequential()
|
||||
|
||||
if stride == 1:
|
||||
self.downsample = None
|
||||
else:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2D(in_planes, planes, 1, stride=stride), self.norm3
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = x
|
||||
y = self.relu(self.norm1(self.conv1(y)))
|
||||
y = self.relu(self.norm2(self.conv2(y)))
|
||||
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
|
||||
return self.relu(x + y)
|
||||
|
||||
|
||||
class BasicEncoder(nn.Layer):
|
||||
"""Basic Encoder with custom normalization."""
|
||||
|
||||
def __init__(self, output_dim=128, norm_fn="batch"):
|
||||
super(BasicEncoder, self).__init__()
|
||||
|
||||
self.norm_fn = norm_fn
|
||||
|
||||
if self.norm_fn == "group":
|
||||
self.norm1 = nn.GroupNorm(8, 64)
|
||||
elif self.norm_fn == "batch":
|
||||
self.norm1 = nn.BatchNorm2D(64)
|
||||
elif self.norm_fn == "instance":
|
||||
self.norm1 = nn.InstanceNorm2D(64)
|
||||
elif self.norm_fn == "none":
|
||||
self.norm1 = nn.Sequential()
|
||||
|
||||
self.conv1 = nn.Conv2D(3, 64, 7, stride=2, padding=3)
|
||||
self.relu1 = nn.ReLU()
|
||||
|
||||
self.in_planes = 64
|
||||
self.layer1 = self._make_layer(64, stride=1)
|
||||
self.layer2 = self._make_layer(128, stride=2)
|
||||
self.layer3 = self._make_layer(192, stride=2)
|
||||
|
||||
self.conv2 = nn.Conv2D(192, output_dim, 1)
|
||||
|
||||
for m in self.sublayers():
|
||||
if isinstance(m, nn.Conv2D):
|
||||
weight_init_(
|
||||
m.weight, "kaiming_normal_", mode="fan_out", nonlinearity="relu"
|
||||
)
|
||||
elif isinstance(m, (nn.BatchNorm2D, nn.InstanceNorm2D, nn.GroupNorm)):
|
||||
weight_init_(m, "Constant", value=1, bias_value=0.0)
|
||||
|
||||
def _make_layer(self, dim, stride=1):
|
||||
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
||||
layers = layer1, layer2
|
||||
|
||||
self.in_planes = dim
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.relu1(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
|
||||
x = self.conv2(x)
|
||||
|
||||
return x
|
||||
48
services/paddle_services/doc_dewarp/plots.py
Normal file
48
services/paddle_services/doc_dewarp/plots.py
Normal file
@@ -0,0 +1,48 @@
import copy
import os

import matplotlib.pyplot as plt
import paddle.optimizer as optim

from GeoTr import GeoTr


def plot_lr_scheduler(optimizer, scheduler, epochs=65, save_dir=""):
    """
    Plot the learning rate scheduler
    """

    optimizer = copy.copy(optimizer)
    scheduler = copy.copy(scheduler)

    lr = []
    for _ in range(epochs):
        for _ in range(30):
            lr.append(scheduler.get_lr())
            optimizer.step()
            scheduler.step()

    epoch = [float(i) / 30.0 for i in range(len(lr))]

    plt.figure()
    plt.plot(epoch, lr, ".-", label="Learning Rate")
    plt.xlabel("epoch")
    plt.ylabel("Learning Rate")
    plt.title("Learning Rate Scheduler")
    plt.savefig(os.path.join(save_dir, "lr_scheduler.png"), dpi=300)
    plt.close()


if __name__ == "__main__":
    model = GeoTr()

    scheduler = optim.lr.OneCycleLR(
        max_learning_rate=1e-4,
        total_steps=1950,
        phase_pct=0.1,
        end_learning_rate=1e-4 / 2.5e5,
    )
    optimizer = optim.AdamW(learning_rate=scheduler, parameters=model.parameters())
    plot_lr_scheduler(
        scheduler=scheduler, optimizer=optimizer, epochs=65, save_dir="./"
    )
124
services/paddle_services/doc_dewarp/position_encoding.py
Normal file
124
services/paddle_services/doc_dewarp/position_encoding.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.initializer as init
|
||||
|
||||
|
||||
class NestedTensor(object):
|
||||
def __init__(self, tensors, mask: Optional[paddle.Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Layer):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
|
||||
):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, mask):
|
||||
assert mask is not None
|
||||
|
||||
y_embed = mask.cumsum(axis=1, dtype="float32")
|
||||
x_embed = mask.cumsum(axis=2, dtype="float32")
|
||||
|
||||
if self.normalize:
|
||||
eps = 1e-06
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = paddle.arange(end=self.num_pos_feats, dtype="float32")
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
|
||||
pos_x = paddle.stack(
|
||||
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), axis=4
|
||||
).flatten(3)
|
||||
pos_y = paddle.stack(
|
||||
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), axis=4
|
||||
).flatten(3)
|
||||
|
||||
pos = paddle.concat((pos_y, pos_x), axis=3).transpose(perm=[0, 3, 1, 2])
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
class PositionEmbeddingLearned(nn.Layer):
|
||||
"""
|
||||
Absolute pos embedding, learned.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=256):
|
||||
super().__init__()
|
||||
self.row_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.col_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
init_uniform = init.Uniform()
init_uniform(self.row_embed.weight)
init_uniform(self.col_embed.weight)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
x = tensor_list.tensors
|
||||
|
||||
h, w = x.shape[-2:]
|
||||
|
||||
i = paddle.arange(end=w)
|
||||
j = paddle.arange(end=h)
|
||||
|
||||
x_emb = self.col_embed(i)
|
||||
y_emb = self.row_embed(j)
|
||||
|
||||
pos = (
|
||||
paddle.concat(
|
||||
[
|
||||
x_emb.unsqueeze(0).tile([h, 1, 1]),
|
||||
y_emb.unsqueeze(1).tile([1, w, 1]),
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
.transpose([2, 0, 1])
|
||||
.unsqueeze(0)
|
||||
.tile([x.shape[0], 1, 1, 1])
|
||||
)
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
def build_position_encoding(hidden_dim=512, position_embedding="sine"):
|
||||
N_steps = hidden_dim // 2
|
||||
|
||||
if position_embedding in ("v2", "sine"):
|
||||
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
|
||||
elif position_embedding in ("v3", "learned"):
|
||||
position_embedding = PositionEmbeddingLearned(N_steps)
|
||||
else:
|
||||
raise ValueError(f"not supported {position_embedding}")
|
||||
return position_embedding
|
||||
69
services/paddle_services/doc_dewarp/predict.py
Normal file
69
services/paddle_services/doc_dewarp/predict.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
import paddle
|
||||
|
||||
from GeoTr import GeoTr
|
||||
from utils import to_image, to_tensor
|
||||
|
||||
|
||||
def run(args):
|
||||
image_path = args.image
|
||||
model_path = args.model
|
||||
output_path = args.output
|
||||
|
||||
checkpoint = paddle.load(model_path)
|
||||
state_dict = checkpoint["model"]
|
||||
model = GeoTr()
|
||||
model.set_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
img_org = cv2.imread(image_path)
|
||||
img = cv2.resize(img_org, (288, 288))
|
||||
x = to_tensor(img)
|
||||
y = to_tensor(img_org)
|
||||
bm = model(x)
|
||||
bm = paddle.nn.functional.interpolate(
|
||||
bm, y.shape[2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
bm_nhwc = bm.transpose([0, 2, 3, 1])
|
||||
out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
|
||||
out_image = to_image(out)
|
||||
cv2.imwrite(output_path, out_image)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="predict")
|
||||
|
||||
parser.add_argument(
|
||||
"--image",
|
||||
"-i",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="",
|
||||
help="The path of image",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"-m",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="",
|
||||
help="The path of model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
"-o",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="",
|
||||
help="The path of output",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(args)
|
||||
|
||||
run(args)
|
||||
7
services/paddle_services/doc_dewarp/requirements.txt
Normal file
7
services/paddle_services/doc_dewarp/requirements.txt
Normal file
@@ -0,0 +1,7 @@
hdf5storage
loguru
numpy
scipy
opencv-python
matplotlib
albumentations
57
services/paddle_services/doc_dewarp/split_dataset.py
Normal file
57
services/paddle_services/doc_dewarp/split_dataset.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
random.seed(1234567)
|
||||
|
||||
|
||||
def run(args):
|
||||
data_root = os.path.expanduser(args.data_root)
|
||||
ratio = args.train_ratio
|
||||
|
||||
data_path = os.path.join(data_root, "img", "*", "*.png")
|
||||
img_list = glob.glob(data_path, recursive=True)
|
||||
img_list.sort()
|
||||
random.shuffle(img_list)
|
||||
|
||||
train_size = int(len(img_list) * ratio)
|
||||
|
||||
train_text_path = os.path.join(data_root, "train.txt")
|
||||
with open(train_text_path, "w") as file:
|
||||
for item in img_list[:train_size]:
|
||||
parts = Path(item).parts
|
||||
item = os.path.join(parts[-2], parts[-1])
|
||||
file.write("%s\n" % item.split(".png")[0])
|
||||
|
||||
val_text_path = os.path.join(data_root, "val.txt")
|
||||
with open(val_text_path, "w") as file:
|
||||
for item in img_list[train_size:]:
|
||||
parts = Path(item).parts
|
||||
item = os.path.join(parts[-2], parts[-1])
|
||||
file.write("%s\n" % item.split(".png")[0])
|
||||
|
||||
logger.info(f"TRAIN LABEL: {train_text_path}")
|
||||
logger.info(f"VAL LABEL: {val_text_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--data_root",
|
||||
type=str,
|
||||
default="~/datasets/doc3d",
|
||||
help="Data path to load data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_ratio", type=float, default=0.8, help="Ratio of training data"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info(args)
|
||||
|
||||
run(args)
|
||||
381
services/paddle_services/doc_dewarp/train.py
Normal file
381
services/paddle_services/doc_dewarp/train.py
Normal file
@@ -0,0 +1,381 @@
|
||||
import argparse
|
||||
import inspect
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.optimizer as optim
|
||||
from loguru import logger
|
||||
from paddle.io import DataLoader
|
||||
from paddle.nn import functional as F
|
||||
from paddle_msssim import ms_ssim, ssim
|
||||
|
||||
from doc3d_dataset import Doc3dDataset
|
||||
from GeoTr import GeoTr
|
||||
from utils import to_image
|
||||
|
||||
FILE = Path(__file__).resolve()
|
||||
ROOT = FILE.parents[0]
|
||||
RANK = int(os.getenv("RANK", -1))
|
||||
|
||||
|
||||
def init_seeds(seed=0, deterministic=False):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
paddle.seed(seed)
|
||||
if deterministic:
|
||||
os.environ["FLAGS_cudnn_deterministic"] = "1"
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
|
||||
|
||||
def colorstr(*input):
|
||||
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code,
|
||||
# i.e. colorstr('blue', 'hello world')
|
||||
|
||||
*args, string = (
|
||||
input if len(input) > 1 else ("blue", "bold", input[0])
|
||||
) # color arguments, string
|
||||
|
||||
colors = {
|
||||
"black": "\033[30m", # basic colors
|
||||
"red": "\033[31m",
|
||||
"green": "\033[32m",
|
||||
"yellow": "\033[33m",
|
||||
"blue": "\033[34m",
|
||||
"magenta": "\033[35m",
|
||||
"cyan": "\033[36m",
|
||||
"white": "\033[37m",
|
||||
"bright_black": "\033[90m", # bright colors
|
||||
"bright_red": "\033[91m",
|
||||
"bright_green": "\033[92m",
|
||||
"bright_yellow": "\033[93m",
|
||||
"bright_blue": "\033[94m",
|
||||
"bright_magenta": "\033[95m",
|
||||
"bright_cyan": "\033[96m",
|
||||
"bright_white": "\033[97m",
|
||||
"end": "\033[0m", # misc
|
||||
"bold": "\033[1m",
|
||||
"underline": "\033[4m",
|
||||
}
|
||||
|
||||
return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
|
||||
|
||||
|
||||
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
|
||||
# Print function arguments (optional args dict)
|
||||
x = inspect.currentframe().f_back # previous frame
|
||||
file, _, func, _, _ = inspect.getframeinfo(x)
|
||||
if args is None: # get args automatically
|
||||
args, _, _, frm = inspect.getargvalues(x)
|
||||
args = {k: v for k, v in frm.items() if k in args}
|
||||
try:
|
||||
file = Path(file).resolve().relative_to(ROOT).with_suffix("")
|
||||
except ValueError:
|
||||
file = Path(file).stem
|
||||
s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
|
||||
logger.info(colorstr(s) + ", ".join(f"{k}={v}" for k, v in args.items()))
|
||||
|
||||
|
||||
def increment_path(path, exist_ok=False, sep="", mkdir=False):
|
||||
# Increment file or directory path,
|
||||
# i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
|
||||
path = Path(path) # os-agnostic
|
||||
if path.exists() and not exist_ok:
|
||||
path, suffix = (
|
||||
(path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
|
||||
)
|
||||
|
||||
for n in range(2, 9999):
|
||||
p = f"{path}{sep}{n}{suffix}" # increment path
|
||||
if not os.path.exists(p):
|
||||
break
|
||||
path = Path(p)
|
||||
|
||||
if mkdir:
|
||||
path.mkdir(parents=True, exist_ok=True) # make directory
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def train(args):
|
||||
save_dir = Path(args.save_dir)
|
||||
|
||||
use_vdl = args.use_vdl
|
||||
if use_vdl:
|
||||
from visualdl import LogWriter
|
||||
|
||||
log_dir = save_dir / "vdl"
|
||||
vdl_writer = LogWriter(str(log_dir))
|
||||
|
||||
# Directories
|
||||
weights_dir = save_dir / "weights"
|
||||
weights_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
last = weights_dir / "last.ckpt"
|
||||
best = weights_dir / "best.ckpt"
|
||||
|
||||
# Hyperparameters
|
||||
|
||||
# Config
|
||||
init_seeds(args.seed)
|
||||
|
||||
# Train loader
|
||||
train_dataset = Doc3dDataset(
|
||||
args.data_root,
|
||||
split="train",
|
||||
is_augment=True,
|
||||
image_size=args.img_size,
|
||||
)
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.workers,
|
||||
)
|
||||
|
||||
# Validation loader
|
||||
val_dataset = Doc3dDataset(
|
||||
args.data_root,
|
||||
split="val",
|
||||
is_augment=False,
|
||||
image_size=args.img_size,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset, batch_size=args.batch_size, num_workers=args.workers
|
||||
)
|
||||
|
||||
# Model
|
||||
model = GeoTr()
|
||||
|
||||
if use_vdl:
|
||||
vdl_writer.add_graph(
|
||||
model,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec([1, 3, args.img_size, args.img_size], "float32")
|
||||
],
|
||||
)
|
||||
|
||||
# Data Parallel Mode
|
||||
if RANK == -1 and paddle.device.cuda.device_count() > 1:
|
||||
model = paddle.DataParallel(model)
|
||||
|
||||
# Scheduler
|
||||
scheduler = optim.lr.OneCycleLR(
|
||||
max_learning_rate=args.lr,
|
||||
total_steps=args.epochs * len(train_loader),
|
||||
phase_pct=0.1,
|
||||
end_learning_rate=args.lr / 2.5e5,
|
||||
)
|
||||
|
||||
# Optimizer
|
||||
optimizer = optim.AdamW(
|
||||
learning_rate=scheduler,
|
||||
parameters=model.parameters(),
|
||||
)
|
||||
|
||||
# loss function
|
||||
l1_loss_fn = nn.L1Loss()
|
||||
mse_loss_fn = nn.MSELoss()
|
||||
|
||||
# Resume
|
||||
best_fitness, start_epoch = 0.0, 0
|
||||
if args.resume:
|
||||
ckpt = paddle.load(args.resume)
|
||||
model.set_state_dict(ckpt["model"])
|
||||
optimizer.set_state_dict(ckpt["optimizer"])
|
||||
scheduler.set_state_dict(ckpt["scheduler"])
|
||||
best_fitness = ckpt["best_fitness"]
|
||||
start_epoch = ckpt["epoch"] + 1
|
||||
|
||||
# Train
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
model.train()
|
||||
|
||||
for i, (img, target) in enumerate(train_loader):
|
||||
img = paddle.to_tensor(img) # NCHW
|
||||
target = paddle.to_tensor(target) # NHWC
|
||||
|
||||
pred = model(img) # NCHW
|
||||
pred_nhwc = pred.transpose([0, 2, 3, 1])
|
||||
|
||||
loss = l1_loss_fn(pred_nhwc, target)
|
||||
mse_loss = mse_loss_fn(pred_nhwc, target)
|
||||
|
||||
if use_vdl:
|
||||
vdl_writer.add_scalar(
|
||||
"Train/L1 Loss", float(loss), epoch * len(train_loader) + i
|
||||
)
|
||||
vdl_writer.add_scalar(
|
||||
"Train/MSE Loss", float(mse_loss), epoch * len(train_loader) + i
|
||||
)
|
||||
vdl_writer.add_scalar(
|
||||
"Train/Learning Rate",
|
||||
float(scheduler.get_lr()),
|
||||
epoch * len(train_loader) + i,
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.clear_grad()
|
||||
|
||||
if i % 10 == 0:
|
||||
logger.info(
|
||||
f"[TRAIN MODE] Epoch: {epoch}, Iter: {i}, L1 Loss: {float(loss)}, "
|
||||
f"MSE Loss: {float(mse_loss)}, LR: {float(scheduler.get_lr())}"
|
||||
)
|
||||
|
||||
# Validation
|
||||
model.eval()
|
||||
|
||||
with paddle.no_grad():
|
||||
avg_ssim = paddle.zeros([])
|
||||
avg_ms_ssim = paddle.zeros([])
|
||||
avg_l1_loss = paddle.zeros([])
|
||||
avg_mse_loss = paddle.zeros([])
|
||||
|
||||
for i, (img, target) in enumerate(val_loader):
|
||||
img = paddle.to_tensor(img)
|
||||
target = paddle.to_tensor(target)
|
||||
|
||||
pred = model(img)
|
||||
pred_nhwc = pred.transpose([0, 2, 3, 1])
|
||||
|
||||
# predict image
|
||||
out = F.grid_sample(img, (pred_nhwc / args.img_size - 0.5) * 2)
|
||||
out_gt = F.grid_sample(img, (target / args.img_size - 0.5) * 2)
|
||||
|
||||
# calculate ssim
|
||||
ssim_val = ssim(out, out_gt, data_range=1.0)
|
||||
ms_ssim_val = ms_ssim(out, out_gt, data_range=1.0)
|
||||
|
||||
loss = l1_loss_fn(pred_nhwc, target)
|
||||
mse_loss = mse_loss_fn(pred_nhwc, target)
|
||||
|
||||
# calculate fitness
|
||||
avg_ssim += ssim_val
|
||||
avg_ms_ssim += ms_ssim_val
|
||||
avg_l1_loss += loss
|
||||
avg_mse_loss += mse_loss
|
||||
|
||||
if i % 10 == 0:
|
||||
logger.info(
|
||||
f"[VAL MODE] Epoch: {epoch}, VAL Iter: {i}, "
|
||||
f"L1 Loss: {float(loss)} MSE Loss: {float(mse_loss)}, "
|
||||
f"MS-SSIM: {float(ms_ssim_val)}, SSIM: {float(ssim_val)}"
|
||||
)
|
||||
|
||||
if use_vdl and i == 0:
|
||||
img_0 = to_image(out[0])
|
||||
img_gt_0 = to_image(out_gt[0])
|
||||
vdl_writer.add_image("Val/Predicted Image No.0", img_0, epoch)
|
||||
vdl_writer.add_image("Val/Target Image No.0", img_gt_0, epoch)
|
||||
|
||||
img_1 = to_image(out[1])
|
||||
img_gt_1 = to_image(out_gt[1])
|
||||
img_gt_1 = img_gt_1.astype("uint8")
|
||||
vdl_writer.add_image("Val/Predicted Image No.1", img_1, epoch)
|
||||
vdl_writer.add_image("Val/Target Image No.1", img_gt_1, epoch)
|
||||
|
||||
img_2 = to_image(out[2])
|
||||
img_gt_2 = to_image(out_gt[2])
|
||||
vdl_writer.add_image("Val/Predicted Image No.2", img_2, epoch)
|
||||
vdl_writer.add_image("Val/Target Image No.2", img_gt_2, epoch)
|
||||
|
||||
avg_ssim /= len(val_loader)
|
||||
avg_ms_ssim /= len(val_loader)
|
||||
avg_l1_loss /= len(val_loader)
|
||||
avg_mse_loss /= len(val_loader)
|
||||
|
||||
if use_vdl:
|
||||
vdl_writer.add_scalar("Val/L1 Loss", float(avg_l1_loss), epoch)
vdl_writer.add_scalar("Val/MSE Loss", float(avg_mse_loss), epoch)
vdl_writer.add_scalar("Val/SSIM", float(avg_ssim), epoch)
vdl_writer.add_scalar("Val/MS-SSIM", float(avg_ms_ssim), epoch)
|
||||
|
||||
# Save
|
||||
ckpt = {
|
||||
"model": model.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"scheduler": scheduler.state_dict(),
|
||||
"best_fitness": best_fitness,
|
||||
"epoch": epoch,
|
||||
}
|
||||
|
||||
paddle.save(ckpt, str(last))
|
||||
|
||||
if best_fitness < avg_ssim:
|
||||
best_fitness = avg_ssim
|
||||
paddle.save(ckpt, str(best))
|
||||
|
||||
if use_vdl:
|
||||
vdl_writer.close()
|
||||
|
||||
|
||||
def main(args):
|
||||
print_args(vars(args))
|
||||
|
||||
args.save_dir = str(
|
||||
increment_path(Path(args.project) / args.name, exist_ok=args.exist_ok)
|
||||
)
|
||||
|
||||
train(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Hyperparams")
|
||||
parser.add_argument(
|
||||
"--data-root",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="~/datasets/doc3d",
|
||||
help="The root path of the dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--img-size",
|
||||
nargs="?",
|
||||
type=int,
|
||||
default=288,
|
||||
help="The size of the input image",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--epochs",
|
||||
nargs="?",
|
||||
type=int,
|
||||
default=65,
|
||||
help="The number of training epochs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size", nargs="?", type=int, default=12, help="Batch Size"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr", nargs="?", type=float, default=1e-04, help="Learning Rate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to previous saved model to restart from",
|
||||
)
|
||||
parser.add_argument("--workers", type=int, default=8, help="max dataloader workers")
|
||||
parser.add_argument(
|
||||
"--project", default=ROOT / "runs/train", help="save to project/name"
|
||||
)
|
||||
parser.add_argument("--name", default="exp", help="save to project/name")
|
||||
parser.add_argument(
|
||||
"--exist-ok",
|
||||
action="store_true",
|
||||
help="existing project/name ok, do not increment",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="Global training seed")
|
||||
parser.add_argument("--use-vdl", action="store_true", help="use VisualDL as logger")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
10
services/paddle_services/doc_dewarp/train.sh
Normal file
@@ -0,0 +1,10 @@
export OPENCV_IO_ENABLE_OPENEXR=1
export FLAGS_logtostderr=0
export CUDA_VISIBLE_DEVICES=0

python train.py --img-size 288 \
    --name "DocTr++" \
    --batch-size 12 \
    --lr 1e-4 \
    --exist-ok \
    --use-vdl
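With --use-vdl enabled as above, the scalar curves and validation images can be inspected while training runs, e.g. with `visualdl --logdir runs/train/DocTr++` (assuming the VisualDL writer in train.py uses the run's save directory, which would be runs/train/DocTr++ given the flags in this script, as its logdir).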
67
services/paddle_services/doc_dewarp/utils.py
Normal file
@@ -0,0 +1,67 @@
import numpy as np
import paddle
from paddle.nn import functional as F


def to_tensor(img: np.ndarray):
    """
    Converts a numpy array image (HWC) to a Paddle tensor (NCHW).

    Args:
        img (numpy.ndarray): The input image as a numpy array.

    Returns:
        out (paddle.Tensor): The output tensor.
    """
    img = img[:, :, ::-1]
    img = img.astype("float32") / 255.0
    img = img.transpose(2, 0, 1)
    out: paddle.Tensor = paddle.to_tensor(img)
    out = paddle.unsqueeze(out, axis=0)

    return out


def to_image(x: paddle.Tensor):
    """
    Converts a Paddle tensor (NCHW) to a numpy array image (HWC).

    Args:
        x (paddle.Tensor): The input tensor.

    Returns:
        out (numpy.ndarray): The output image as a numpy array.
    """
    out: np.ndarray = x.squeeze().numpy()
    out = out.transpose(1, 2, 0)
    out = out * 255.0
    out = out.astype("uint8")
    out = out[:, :, ::-1]

    return out


def unwarp(img, bm, bm_data_format="NCHW"):
    """
    Unwarp an image using a flow field.

    Args:
        img (paddle.Tensor): The input image.
        bm (paddle.Tensor): The flow field.
        bm_data_format (str): Data format of the flow field, "NCHW" or "NHWC".

    Returns:
        out (paddle.Tensor): The output image.
    """
    _, _, h, w = img.shape

    if bm_data_format == "NHWC":
        bm = bm.transpose([0, 3, 1, 2])

    # NCHW
    bm = F.upsample(bm, size=(h, w), mode="bilinear", align_corners=True)
    # NHWC
    bm = bm.transpose([0, 2, 3, 1])
    # NCHW
    out = F.grid_sample(img, bm)

    return out
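The three helpers above are meant to be chained for single-image inference. A minimal sketch, assuming a `model` that predicts a backward map already scaled to grid_sample's [-1, 1] range in NCHW layout, and a hypothetical input file `sample.jpg`:

    import cv2

    from utils import to_image, to_tensor, unwarp

    img = cv2.imread("sample.jpg")               # HWC, BGR, uint8
    x = to_tensor(img)                           # 1x3xHxW, RGB, float32 in [0, 1]
    bm = model(x)                                # assumed flow field in [-1, 1], NCHW
    flat = unwarp(x, bm, bm_data_format="NCHW")  # resampled (dewarped) image, NCHW
    cv2.imwrite("sample_dewarped.jpg", to_image(flat))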
152
services/paddle_services/doc_dewarp/weight_init.py
Normal file
@@ -0,0 +1,152 @@
import math

import numpy as np
import paddle
import paddle.nn.initializer as init
from scipy import special


def weight_init_(
    layer, func, weight_name=None, bias_name=None, bias_value=0.0, **kwargs
):
    """
    In-place params init function.
    Usage:
    .. code-block:: python

        import paddle
        import numpy as np

        data = np.ones([3, 4], dtype='float32')
        linear = paddle.nn.Linear(4, 4)
        input = paddle.to_tensor(data)
        print(linear.weight)
        linear(input)

        weight_init_(linear, 'Normal', 'fc_w0', 'fc_b0', std=0.01, mean=0.1)
        print(linear.weight)
    """

    if hasattr(layer, "weight") and layer.weight is not None:
        getattr(init, func)(**kwargs)(layer.weight)
        if weight_name is not None:
            # override weight name
            layer.weight.name = weight_name

    if hasattr(layer, "bias") and layer.bias is not None:
        init.Constant(bias_value)(layer.bias)
        if bias_name is not None:
            # override bias name
            layer.bias.name = bias_name


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        print(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect."
        )

    with paddle.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to [2l-1, 2u-1].
        tmp = np.random.uniform(
            2 * lower - 1, 2 * upper - 1, size=list(tensor.shape)
        ).astype(np.float32)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tmp = special.erfinv(tmp)

        # Transform to proper mean, std
        tmp *= std * math.sqrt(2.0)
        tmp += mean

        # Clamp to ensure it's in the proper range
        tmp = np.clip(tmp, a, b)
        tensor.set_value(paddle.to_tensor(tmp))

        return tensor


def _calculate_fan_in_and_fan_out(tensor):
    dimensions = tensor.dim()
    if dimensions < 2:
        raise ValueError(
            "Fan in and fan out can not be computed for tensor "
            "with fewer than 2 dimensions"
        )

    num_input_fmaps = tensor.shape[1]
    num_output_fmaps = tensor.shape[0]
    receptive_field_size = 1
    if tensor.dim() > 2:
        receptive_field_size = tensor[0][0].numel()
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def kaiming_normal_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
    def _calculate_correct_fan(tensor, mode):
        mode = mode.lower()
        valid_modes = ["fan_in", "fan_out"]
        if mode not in valid_modes:
            raise ValueError(
                "Mode {} not supported, please use one of {}".format(mode, valid_modes)
            )

        fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
        return fan_in if mode == "fan_in" else fan_out

    def calculate_gain(nonlinearity, param=None):
        linear_fns = [
            "linear",
            "conv1d",
            "conv2d",
            "conv3d",
            "conv_transpose1d",
            "conv_transpose2d",
            "conv_transpose3d",
        ]
        if nonlinearity in linear_fns or nonlinearity == "sigmoid":
            return 1
        elif nonlinearity == "tanh":
            return 5.0 / 3
        elif nonlinearity == "relu":
            return math.sqrt(2.0)
        elif nonlinearity == "leaky_relu":
            if param is None:
                negative_slope = 0.01
            elif (
                not isinstance(param, bool)
                and isinstance(param, int)
                or isinstance(param, float)
            ):
                negative_slope = param
            else:
                raise ValueError("negative_slope {} not a valid number".format(param))
            return math.sqrt(2.0 / (1 + negative_slope**2))
        else:
            raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))

    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with paddle.no_grad():
        paddle.nn.initializer.Normal(0, std)(tensor)
    return tensor
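A minimal sketch of how these initializers might be applied when building a network; the layer shapes below are illustrative and not taken from this repository:

    import paddle.nn as nn

    from weight_init import kaiming_normal_, trunc_normal_, weight_init_

    conv = nn.Conv2D(3, 64, kernel_size=3)
    kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")

    fc = nn.Linear(64, 10)
    trunc_normal_(fc.weight, std=0.02)

    # Equivalent to the docstring example: Normal init via the paddle initializer registry
    weight_init_(fc, "Normal", std=0.01, mean=0.0)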
36
services/paddle_services/ie_cost.py
Normal file
@@ -0,0 +1,36 @@
import json
import logging.config

from flask import Flask, request
from paddlenlp import Taskflow

from __init__ import IE_KEY
from log import LOGGING_CONFIG
from utils import process_request

app = Flask(__name__)
COST_LIST_SCHEMA = tuple(IE_KEY[key] for key in [
    'name', 'admission_date', 'discharge_date', 'medical_expenses', 'page'
])
COST = Taskflow('information_extraction', schema=COST_LIST_SCHEMA, model='uie-x-base',
                task_path='model/cost_list_model', layout_analysis=False, precision='fp16')


@app.route('/', methods=['POST'], endpoint='cost')
@process_request
def main():
    img_path = request.form.get('img_path')
    layout = request.form.get('layout')
    return COST({'doc': img_path, 'layout': json.loads(layout)})


@app.route('/text', methods=['POST'])
@process_request
def text():
    t = request.form.get('text')
    return COST(t)


if __name__ == '__main__':
    logging.config.dictConfig(LOGGING_CONFIG)
    app.run('0.0.0.0', 5004)
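Each IE service takes multipart form fields rather than a JSON body. A minimal client sketch, assuming the cost service is reachable on localhost:5004; the image path and the layout structure below are hypothetical placeholders, and the response shape depends on what process_request returns (the discharge and settlement services on ports 5003 and 5002 follow the same pattern):

    import json

    import requests

    layout = [[[10, 10, 200, 40], "some OCR text"]]  # hypothetical layout: [bbox, text] pairs
    resp = requests.post(
        "http://localhost:5004/",
        data={"img_path": "/tmp_img/cost.jpg", "layout": json.dumps(layout)},
    )
    print(resp.text)  # or resp.json(), if process_request serializes to JSON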
36
services/paddle_services/ie_discharge.py
Normal file
@@ -0,0 +1,36 @@
import json
import logging.config

from flask import Flask, request
from paddlenlp import Taskflow

from __init__ import IE_KEY
from log import LOGGING_CONFIG
from utils import process_request

app = Flask(__name__)
DISCHARGE_RECORD_SCHEMA = tuple(IE_KEY[key] for key in [
    'hospital', 'department', 'name', 'admission_date', 'discharge_date', 'doctor', 'admission_id', 'age'
])
DISCHARGE = Taskflow('information_extraction', schema=DISCHARGE_RECORD_SCHEMA, model='uie-x-base',
                     task_path='model/discharge_record_model', layout_analysis=False, precision='fp16')


@app.route('/', methods=['POST'], endpoint='discharge')
@process_request
def main():
    img_path = request.form.get('img_path')
    layout = request.form.get('layout')
    return DISCHARGE({'doc': img_path, 'layout': json.loads(layout)})


@app.route('/text', methods=['POST'])
@process_request
def text():
    t = request.form.get('text')
    return DISCHARGE(t)


if __name__ == '__main__':
    logging.config.dictConfig(LOGGING_CONFIG)
    app.run('0.0.0.0', 5003)
38
services/paddle_services/ie_settlement.py
Normal file
@@ -0,0 +1,38 @@
import json
import logging.config

from flask import Flask, request
from paddlenlp import Taskflow

from __init__ import IE_KEY
from log import LOGGING_CONFIG
from utils import process_request

app = Flask(__name__)
SETTLEMENT_LIST_SCHEMA = tuple(IE_KEY[key] for key in [
    'name', 'admission_date', 'discharge_date', 'medical_expenses', 'personal_cash_payment',
    'personal_account_payment', 'personal_funded_amount', 'medical_insurance_type', 'admission_id', 'settlement_id',
    'uppercase_medical_expenses'
])
SETTLEMENT_IE = Taskflow('information_extraction', schema=SETTLEMENT_LIST_SCHEMA, model='uie-x-base',
                         task_path='model/settlement_list_model', layout_analysis=False, precision='fp16')


@app.route('/', methods=['POST'], endpoint='settlement')
@process_request
def main():
    img_path = request.form.get('img_path')
    layout = request.form.get('layout')
    return SETTLEMENT_IE({'doc': img_path, 'layout': json.loads(layout)})


@app.route('/text', methods=['POST'])
@process_request
def text():
    t = request.form.get('text')
    return SETTLEMENT_IE(t)


if __name__ == '__main__':
    logging.config.dictConfig(LOGGING_CONFIG)
    app.run('0.0.0.0', 5002)
70
services/paddle_services/log/__init__.py
Normal file
@@ -0,0 +1,70 @@
import os
import socket

# Use the hostname to tell containers apart
HOSTNAME = socket.gethostname()
# Create the log directories if they do not exist yet
LOG_PATHS = [
    f'log/{HOSTNAME}/error',
]
for path in LOG_PATHS:
    if not os.path.exists(path):
        os.makedirs(path)

# Logging configuration dictionary
LOGGING_CONFIG = {
    'version': 1,  # Required; version of the config schema
    'disable_existing_loggers': False,  # Whether to disable loggers created before this config is applied

    # formatters define the layout of log records
    'formatters': {
        'standard': {
            'format': '%(asctime)s [%(levelname)s] %(name)s: %(message)s',
            'datefmt': '%Y-%m-%d %H:%M:%S',
        },
    },

    # handlers define where and how log records are emitted
    'handlers': {
        'console': {
            'class': 'logging.StreamHandler',  # Console handler
            'level': 'DEBUG',
            'formatter': 'standard',
            'stream': 'ext://sys.stdout',  # Write to stdout; encoding follows the system, usually UTF-8
        },
        'file': {
            'class': 'logging.handlers.TimedRotatingFileHandler',  # File handler with log rotation
            'level': 'INFO',
            'formatter': 'standard',
            'filename': f'log/{HOSTNAME}/fcb_photo_review.log',  # Log file path
            'when': 'midnight',
            'interval': 1,
            'backupCount': 14,  # Number of rotated backups to keep
            'encoding': 'utf-8',  # Explicit UTF-8 encoding so Chinese text is written correctly
        },
        'error': {
            'class': 'logging.handlers.TimedRotatingFileHandler',
            'level': 'INFO',
            'formatter': 'standard',
            'filename': f'log/{HOSTNAME}/error/fcb_photo_review_error.log',
            'when': 'midnight',
            'interval': 1,
            'backupCount': 14,
            'encoding': 'utf-8',
        },
    },

    # loggers define the logger instances
    'loggers': {
        '': {  # Root logger
            'handlers': ['console', 'file'],  # Handlers attached to this logger
            'level': 'DEBUG',  # Level of the root logger
            'propagate': False,  # Whether records propagate to ancestor loggers
        },
        'error': {
            'handlers': ['console', 'file', 'error'],
            'level': 'DEBUG',
            'propagate': False,
        },
    },
}
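A minimal sketch of how a service module is expected to wire this configuration up and route failures to the dedicated error log (mirroring the `if __name__ == '__main__'` blocks of the Flask services above; the log messages are illustrative):

    import logging
    import logging.config

    from log import LOGGING_CONFIG

    logging.config.dictConfig(LOGGING_CONFIG)

    logger = logging.getLogger(__name__)       # handled by the root logger: console + rotating file
    error_logger = logging.getLogger('error')  # additionally written under log/<hostname>/error/

    logger.info("service started")
    try:
        raise ValueError("example failure")
    except ValueError:
        error_logger.exception("request failed")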
1
services/paddle_services/model/cost_list_model/README.md
Normal file
@@ -0,0 +1 @@
Directory for the fine-tuned hospitalization cost list information extraction model
1
services/paddle_services/model/dewarp_model/README.md
Normal file
@@ -0,0 +1 @@
Directory for the image dewarping (distortion correction) model
@@ -0,0 +1 @@
Directory for the fine-tuned discharge record information extraction model
@@ -0,0 +1 @@
Directory for the document detection model
@@ -0,0 +1 @@
Directory for the fine-tuned basic medical insurance settlement statement information extraction model
@@ -0,0 +1 @@
Directory for the text classification model
24
services/paddle_services/ocr.py
Normal file
@@ -0,0 +1,24 @@
import logging.config

from flask import Flask, request
from paddleocr import PaddleOCR

from log import LOGGING_CONFIG
from utils import process_request

app = Flask(__name__)
# To avoid recognizing spaces, use_space_char=False can be set. Test this setting carefully:
# in version 2.7.3 it has a bug that causes recognition to fail.
OCR = PaddleOCR(use_angle_cls=False, show_log=False, det_db_thresh=0.1, det_db_box_thresh=0.3, det_limit_side_len=1248,
                drop_score=0.3)


@app.route('/', methods=['POST'])
@process_request
def main():
    img_path = request.form.get('img_path')
    return OCR.ocr(img_path, cls=False)


if __name__ == '__main__':
    logging.config.dictConfig(LOGGING_CONFIG)
    app.run('0.0.0.0', 5001)
1
services/paddle_services/paddle_detection/README.md
Normal file
@@ -0,0 +1 @@
README_cn.md
7
services/paddle_services/paddle_detection/__init__.py
Normal file
@@ -0,0 +1,7 @@
import os

from onnxruntime import InferenceSession

MODEL_DIR = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'model', 'object_det_model')
PADDLE_DET = InferenceSession(os.path.join(MODEL_DIR, 'ppyoloe_plus_crn_l_80e_coco_w_nms.onnx'),
                              providers=['CUDAExecutionProvider'], provider_options=[{'device_id': 0}])
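A minimal sketch of querying the exported PP-YOLOE model. The preprocessing (640x640 resize, /255 scaling), the input names ('image', 'scale_factor'), and the output layout are assumptions about this particular export and should be verified against PADDLE_DET.get_inputs(); the test image path is hypothetical:

    import cv2
    import numpy as np

    from paddle_detection import PADDLE_DET

    # Verify the actual input names/shapes of the exported graph before feeding data
    for inp in PADDLE_DET.get_inputs():
        print(inp.name, inp.shape)

    img = cv2.imread('test_img/sample.jpg')  # hypothetical test image
    h, w = img.shape[:2]
    blob = cv2.resize(img, (640, 640)).astype(np.float32) / 255.0  # assumed 640x640 input
    blob = np.ascontiguousarray(blob[:, :, ::-1].transpose(2, 0, 1))[None]
    scale = np.array([[640 / h, 640 / w]], dtype=np.float32)

    # Assumed input names for a PaddleDetection ONNX export with NMS fused in
    dets = PADDLE_DET.run(None, {'image': blob, 'scale_factor': scale})[0]
    print(dets[:5])  # commonly rows of [class_id, score, x1, y1, x2, y2]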
Some files were not shown because too many files have changed in this diff