移动doc_dewarp
This commit is contained in:
167
services/paddle_services/doc_dewarp/.gitignore
vendored
Normal file
167
services/paddle_services/doc_dewarp/.gitignore
vendored
Normal file
@@ -0,0 +1,167 @@
|
||||
input/
|
||||
output/
|
||||
|
||||
runs/
|
||||
*_dump/
|
||||
*_log/
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/#use-with-ide
|
||||
.pdm.toml
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
34
services/paddle_services/doc_dewarp/.pre-commit-config.yaml
Normal file
34
services/paddle_services/doc_dewarp/.pre-commit-config.yaml
Normal file
@@ -0,0 +1,34 @@
|
||||
repos:
|
||||
# Common hooks
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: check-added-large-files
|
||||
- id: check-merge-conflict
|
||||
- id: check-symlinks
|
||||
- id: detect-private-key
|
||||
- id: end-of-file-fixer
|
||||
- id: trailing-whitespace
|
||||
- repo: https://github.com/Lucas-C/pre-commit-hooks.git
|
||||
rev: v1.5.1
|
||||
hooks:
|
||||
- id: remove-crlf
|
||||
- id: remove-tabs
|
||||
name: Tabs remover (Python)
|
||||
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
|
||||
args: [--whitespaces-count, '4']
|
||||
# For Python files
|
||||
- repo: https://github.com/psf/black.git
|
||||
rev: 23.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.11.5
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.0.272
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix, --exit-non-zero-on-fix, --no-cache]
|
||||
398
services/paddle_services/doc_dewarp/GeoTr.py
Normal file
398
services/paddle_services/doc_dewarp/GeoTr.py
Normal file
@@ -0,0 +1,398 @@
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.functional as F
|
||||
|
||||
from .extractor import BasicEncoder
|
||||
from .position_encoding import build_position_encoding
|
||||
from .weight_init import weight_init_
|
||||
|
||||
|
||||
class attnLayer(nn.Layer):
|
||||
def __init__(
|
||||
self,
|
||||
d_model,
|
||||
nhead=8,
|
||||
dim_feedforward=2048,
|
||||
dropout=0.1,
|
||||
activation="relu",
|
||||
normalize_before=False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.self_attn = nn.MultiHeadAttention(d_model, nhead, dropout=dropout)
|
||||
self.multihead_attn_list = nn.LayerList(
|
||||
[
|
||||
copy.deepcopy(nn.MultiHeadAttention(d_model, nhead, dropout=dropout))
|
||||
for i in range(2)
|
||||
]
|
||||
)
|
||||
|
||||
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
||||
self.dropout = nn.Dropout(p=dropout)
|
||||
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
||||
|
||||
self.norm1 = nn.LayerNorm(d_model)
|
||||
self.norm2_list = nn.LayerList(
|
||||
[copy.deepcopy(nn.LayerNorm(d_model)) for i in range(2)]
|
||||
)
|
||||
|
||||
self.norm3 = nn.LayerNorm(d_model)
|
||||
self.dropout1 = nn.Dropout(p=dropout)
|
||||
self.dropout2_list = nn.LayerList(
|
||||
[copy.deepcopy(nn.Dropout(p=dropout)) for i in range(2)]
|
||||
)
|
||||
self.dropout3 = nn.Dropout(p=dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[paddle.Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward_post(
|
||||
self,
|
||||
tgt,
|
||||
memory_list,
|
||||
tgt_mask=None,
|
||||
memory_mask=None,
|
||||
pos=None,
|
||||
memory_pos=None,
|
||||
):
|
||||
q = k = self.with_pos_embed(tgt, pos)
|
||||
tgt2 = self.self_attn(
|
||||
q.transpose((1, 0, 2)),
|
||||
k.transpose((1, 0, 2)),
|
||||
value=tgt.transpose((1, 0, 2)),
|
||||
attn_mask=tgt_mask,
|
||||
)
|
||||
tgt2 = tgt2.transpose((1, 0, 2))
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt = self.norm1(tgt)
|
||||
|
||||
for memory, multihead_attn, norm2, dropout2, m_pos in zip(
|
||||
memory_list,
|
||||
self.multihead_attn_list,
|
||||
self.norm2_list,
|
||||
self.dropout2_list,
|
||||
memory_pos,
|
||||
):
|
||||
tgt2 = multihead_attn(
|
||||
query=self.with_pos_embed(tgt, pos).transpose((1, 0, 2)),
|
||||
key=self.with_pos_embed(memory, m_pos).transpose((1, 0, 2)),
|
||||
value=memory.transpose((1, 0, 2)),
|
||||
attn_mask=memory_mask,
|
||||
).transpose((1, 0, 2))
|
||||
|
||||
tgt = tgt + dropout2(tgt2)
|
||||
tgt = norm2(tgt)
|
||||
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
tgt = self.norm3(tgt)
|
||||
|
||||
return tgt
|
||||
|
||||
def forward_pre(
|
||||
self,
|
||||
tgt,
|
||||
memory,
|
||||
tgt_mask=None,
|
||||
memory_mask=None,
|
||||
pos=None,
|
||||
memory_pos=None,
|
||||
):
|
||||
tgt2 = self.norm1(tgt)
|
||||
|
||||
q = k = self.with_pos_embed(tgt2, pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask)
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
tgt2 = self.norm2(tgt)
|
||||
|
||||
tgt2 = self.multihead_attn(
|
||||
query=self.with_pos_embed(tgt2, pos),
|
||||
key=self.with_pos_embed(memory, memory_pos),
|
||||
value=memory,
|
||||
attn_mask=memory_mask,
|
||||
)
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
tgt2 = self.norm3(tgt)
|
||||
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout3(tgt2)
|
||||
|
||||
return tgt
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt,
|
||||
memory_list,
|
||||
tgt_mask=None,
|
||||
memory_mask=None,
|
||||
pos=None,
|
||||
memory_pos=None,
|
||||
):
|
||||
if self.normalize_before:
|
||||
return self.forward_pre(
|
||||
tgt,
|
||||
memory_list,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
pos,
|
||||
memory_pos,
|
||||
)
|
||||
return self.forward_post(
|
||||
tgt,
|
||||
memory_list,
|
||||
tgt_mask,
|
||||
memory_mask,
|
||||
pos,
|
||||
memory_pos,
|
||||
)
|
||||
|
||||
|
||||
def _get_clones(module, N):
|
||||
return nn.LayerList([copy.deepcopy(module) for i in range(N)])
|
||||
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
|
||||
|
||||
|
||||
class TransDecoder(nn.Layer):
|
||||
def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
|
||||
super(TransDecoder, self).__init__()
|
||||
|
||||
attn_layer = attnLayer(hidden_dim)
|
||||
self.layers = _get_clones(attn_layer, num_attn_layers)
|
||||
self.position_embedding = build_position_encoding(hidden_dim)
|
||||
|
||||
def forward(self, image: paddle.Tensor, query_embed: paddle.Tensor):
|
||||
pos = self.position_embedding(
|
||||
paddle.ones([image.shape[0], image.shape[2], image.shape[3]], dtype="bool")
|
||||
)
|
||||
|
||||
b, c, h, w = image.shape
|
||||
|
||||
image = image.flatten(2).transpose(perm=[2, 0, 1])
|
||||
pos = pos.flatten(2).transpose(perm=[2, 0, 1])
|
||||
|
||||
for layer in self.layers:
|
||||
query_embed = layer(query_embed, [image], pos=pos, memory_pos=[pos, pos])
|
||||
|
||||
query_embed = query_embed.transpose(perm=[1, 2, 0]).reshape([b, c, h, w])
|
||||
|
||||
return query_embed
|
||||
|
||||
|
||||
class TransEncoder(nn.Layer):
|
||||
def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
|
||||
super(TransEncoder, self).__init__()
|
||||
|
||||
attn_layer = attnLayer(hidden_dim)
|
||||
self.layers = _get_clones(attn_layer, num_attn_layers)
|
||||
self.position_embedding = build_position_encoding(hidden_dim)
|
||||
|
||||
def forward(self, image: paddle.Tensor):
|
||||
pos = self.position_embedding(
|
||||
paddle.ones([image.shape[0], image.shape[2], image.shape[3]], dtype="bool")
|
||||
)
|
||||
|
||||
b, c, h, w = image.shape
|
||||
|
||||
image = image.flatten(2).transpose(perm=[2, 0, 1])
|
||||
pos = pos.flatten(2).transpose(perm=[2, 0, 1])
|
||||
|
||||
for layer in self.layers:
|
||||
image = layer(image, [image], pos=pos, memory_pos=[pos, pos])
|
||||
|
||||
image = image.transpose(perm=[1, 2, 0]).reshape([b, c, h, w])
|
||||
|
||||
return image
|
||||
|
||||
|
||||
class FlowHead(nn.Layer):
|
||||
def __init__(self, input_dim=128, hidden_dim=256):
|
||||
super(FlowHead, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2D(input_dim, hidden_dim, 3, padding=1)
|
||||
self.conv2 = nn.Conv2D(hidden_dim, 2, 3, padding=1)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
return self.conv2(self.relu(self.conv1(x)))
|
||||
|
||||
|
||||
class UpdateBlock(nn.Layer):
|
||||
def __init__(self, hidden_dim: int = 128):
|
||||
super(UpdateBlock, self).__init__()
|
||||
|
||||
self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
|
||||
|
||||
self.mask = nn.Sequential(
|
||||
nn.Conv2D(hidden_dim, 256, 3, padding=1),
|
||||
nn.ReLU(),
|
||||
nn.Conv2D(256, 64 * 9, 1, padding=0),
|
||||
)
|
||||
|
||||
def forward(self, image, coords):
|
||||
mask = 0.25 * self.mask(image)
|
||||
dflow = self.flow_head(image)
|
||||
coords = coords + dflow
|
||||
return mask, coords
|
||||
|
||||
|
||||
def coords_grid(batch, ht, wd):
|
||||
coords = paddle.meshgrid(paddle.arange(end=ht), paddle.arange(end=wd))
|
||||
coords = paddle.stack(coords[::-1], axis=0).astype(dtype="float32")
|
||||
return coords[None].tile([batch, 1, 1, 1])
|
||||
|
||||
|
||||
def upflow8(flow, mode="bilinear"):
|
||||
new_size = 8 * flow.shape[2], 8 * flow.shape[3]
|
||||
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
|
||||
|
||||
|
||||
class OverlapPatchEmbed(nn.Layer):
|
||||
"""Image to Patch Embedding"""
|
||||
|
||||
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
|
||||
super().__init__()
|
||||
|
||||
img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
|
||||
patch_size = (
|
||||
patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
|
||||
)
|
||||
|
||||
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
||||
self.num_patches = self.H * self.W
|
||||
|
||||
self.proj = nn.Conv2D(
|
||||
in_chans,
|
||||
embed_dim,
|
||||
patch_size,
|
||||
stride=stride,
|
||||
padding=(patch_size[0] // 2, patch_size[1] // 2),
|
||||
)
|
||||
self.norm = nn.LayerNorm(embed_dim)
|
||||
|
||||
self.apply(self._init_weights)
|
||||
|
||||
def _init_weights(self, m):
|
||||
if isinstance(m, nn.Linear):
|
||||
weight_init_(m, "trunc_normal_", std=0.02)
|
||||
elif isinstance(m, nn.LayerNorm):
|
||||
weight_init_(m, "Constant", value=1.0)
|
||||
elif isinstance(m, nn.Conv2D):
|
||||
weight_init_(
|
||||
m.weight, "kaiming_normal_", mode="fan_out", nonlinearity="relu"
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
_, _, H, W = x.shape
|
||||
x = x.flatten(2)
|
||||
|
||||
perm = list(range(x.ndim))
|
||||
perm[1] = 2
|
||||
perm[2] = 1
|
||||
x = x.transpose(perm=perm)
|
||||
|
||||
x = self.norm(x)
|
||||
|
||||
return x, H, W
|
||||
|
||||
|
||||
class GeoTr(nn.Layer):
|
||||
def __init__(self):
|
||||
super(GeoTr, self).__init__()
|
||||
|
||||
self.hidden_dim = hdim = 256
|
||||
|
||||
self.fnet = BasicEncoder(output_dim=hdim, norm_fn="instance")
|
||||
|
||||
self.encoder_block = [("encoder_block" + str(i)) for i in range(3)]
|
||||
for i in self.encoder_block:
|
||||
self.__setattr__(i, TransEncoder(2, hidden_dim=hdim))
|
||||
|
||||
self.down_layer = [("down_layer" + str(i)) for i in range(2)]
|
||||
for i in self.down_layer:
|
||||
self.__setattr__(i, nn.Conv2D(256, 256, 3, stride=2, padding=1))
|
||||
|
||||
self.decoder_block = [("decoder_block" + str(i)) for i in range(3)]
|
||||
for i in self.decoder_block:
|
||||
self.__setattr__(i, TransDecoder(2, hidden_dim=hdim))
|
||||
|
||||
self.up_layer = [("up_layer" + str(i)) for i in range(2)]
|
||||
for i in self.up_layer:
|
||||
self.__setattr__(
|
||||
i, nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
|
||||
)
|
||||
|
||||
self.query_embed = nn.Embedding(81, self.hidden_dim)
|
||||
|
||||
self.update_block = UpdateBlock(self.hidden_dim)
|
||||
|
||||
def initialize_flow(self, img):
|
||||
N, _, H, W = img.shape
|
||||
coodslar = coords_grid(N, H, W)
|
||||
coords0 = coords_grid(N, H // 8, W // 8)
|
||||
coords1 = coords_grid(N, H // 8, W // 8)
|
||||
return coodslar, coords0, coords1
|
||||
|
||||
def upsample_flow(self, flow, mask):
|
||||
N, _, H, W = flow.shape
|
||||
|
||||
mask = mask.reshape([N, 1, 9, 8, 8, H, W])
|
||||
mask = F.softmax(mask, axis=2)
|
||||
|
||||
up_flow = F.unfold(8 * flow, [3, 3], paddings=1)
|
||||
up_flow = up_flow.reshape([N, 2, 9, 1, 1, H, W])
|
||||
|
||||
up_flow = paddle.sum(mask * up_flow, axis=2)
|
||||
up_flow = up_flow.transpose(perm=[0, 1, 4, 2, 5, 3])
|
||||
|
||||
return up_flow.reshape([N, 2, 8 * H, 8 * W])
|
||||
|
||||
def forward(self, image):
|
||||
fmap = self.fnet(image)
|
||||
fmap = F.relu(fmap)
|
||||
|
||||
fmap1 = self.__getattr__(self.encoder_block[0])(fmap)
|
||||
fmap1d = self.__getattr__(self.down_layer[0])(fmap1)
|
||||
fmap2 = self.__getattr__(self.encoder_block[1])(fmap1d)
|
||||
fmap2d = self.__getattr__(self.down_layer[1])(fmap2)
|
||||
fmap3 = self.__getattr__(self.encoder_block[2])(fmap2d)
|
||||
|
||||
query_embed0 = self.query_embed.weight.unsqueeze(1).tile([1, fmap3.shape[0], 1])
|
||||
|
||||
fmap3d_ = self.__getattr__(self.decoder_block[0])(fmap3, query_embed0)
|
||||
fmap3du_ = (
|
||||
self.__getattr__(self.up_layer[0])(fmap3d_)
|
||||
.flatten(2)
|
||||
.transpose(perm=[2, 0, 1])
|
||||
)
|
||||
fmap2d_ = self.__getattr__(self.decoder_block[1])(fmap2, fmap3du_)
|
||||
fmap2du_ = (
|
||||
self.__getattr__(self.up_layer[1])(fmap2d_)
|
||||
.flatten(2)
|
||||
.transpose(perm=[2, 0, 1])
|
||||
)
|
||||
fmap_out = self.__getattr__(self.decoder_block[2])(fmap1, fmap2du_)
|
||||
|
||||
coodslar, coords0, coords1 = self.initialize_flow(image)
|
||||
coords1 = coords1.detach()
|
||||
mask, coords1 = self.update_block(fmap_out, coords1)
|
||||
flow_up = self.upsample_flow(coords1 - coords0, mask)
|
||||
bm_up = coodslar + flow_up
|
||||
|
||||
return bm_up
|
||||
73
services/paddle_services/doc_dewarp/README.md
Normal file
73
services/paddle_services/doc_dewarp/README.md
Normal file
@@ -0,0 +1,73 @@
|
||||
# DocTrPP: DocTr++ in PaddlePaddle
|
||||
|
||||
## Introduction
|
||||
|
||||
This is a PaddlePaddle implementation of DocTr++. The original paper is [DocTr++: Deep Unrestricted Document Image Rectification](https://arxiv.org/abs/2304.08796). The original code is [here](https://github.com/fh2019ustc/DocTr-Plus).
|
||||
|
||||

|
||||
|
||||
## Requirements
|
||||
|
||||
You need to install the latest version of PaddlePaddle, which is done through this [link](https://www.paddlepaddle.org.cn/).
|
||||
|
||||
## Training
|
||||
|
||||
1. Data Preparation
|
||||
|
||||
To prepare datasets, refer to [doc3D](https://github.com/cvlab-stonybrook/doc3D-dataset).
|
||||
|
||||
2. Training
|
||||
|
||||
```shell
|
||||
sh train.sh
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```shell
|
||||
export OPENCV_IO_ENABLE_OPENEXR=1
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
python train.py --img-size 288 \
|
||||
--name "DocTr++" \
|
||||
--batch-size 12 \
|
||||
--lr 2.5e-5 \
|
||||
--exist-ok \
|
||||
--use-vdl
|
||||
```
|
||||
|
||||
3. Load Trained Model and Continue Training
|
||||
|
||||
```shell
|
||||
export OPENCV_IO_ENABLE_OPENEXR=1
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
python train.py --img-size 288 \
|
||||
--name "DocTr++" \
|
||||
--batch-size 12 \
|
||||
--lr 2.5e-5 \
|
||||
--resume "runs/train/DocTr++/weights/last.ckpt" \
|
||||
--exist-ok \
|
||||
--use-vdl
|
||||
```
|
||||
|
||||
## Test and Inference
|
||||
|
||||
Test the dewarp result on a single image:
|
||||
|
||||
```shell
|
||||
python predict.py -i "crop/12_2 copy.png" -m runs/train/DocTr++/weights/best.ckpt -o 12.2.png
|
||||
```
|
||||

|
||||
|
||||
## Export to onnx
|
||||
|
||||
```
|
||||
pip install paddle2onnx
|
||||
|
||||
python export.py -m ./best.ckpt --format onnx
|
||||
```
|
||||
|
||||
## Model Download
|
||||
|
||||
The trained model can be downloaded from [here](https://github.com/GreatV/DocTrPP/releases/download/v0.0.2/best.ckpt).
|
||||
4
services/paddle_services/doc_dewarp/__init__.py
Normal file
4
services/paddle_services/doc_dewarp/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from onnxruntime import InferenceSession
|
||||
|
||||
DOC_TR = InferenceSession("model/dewarp_model/doc_tr_pp.onnx",
|
||||
providers=["CUDAExecutionProvider"], provider_options=[{"device_id": 0}])
|
||||
133
services/paddle_services/doc_dewarp/data_visualization.py
Normal file
133
services/paddle_services/doc_dewarp/data_visualization.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
|
||||
import cv2
|
||||
import hdf5storage as h5
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn.functional as F
|
||||
|
||||
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--data_root",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="~/datasets/doc3d/",
|
||||
help="Path to the downloaded dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--folder", nargs="?", type=int, default=1, help="Folder ID to read from"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="output.png",
|
||||
help="Output filename for the image",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
root = os.path.expanduser(args.data_root)
|
||||
folder = args.folder
|
||||
dirname = os.path.join(root, "img", str(folder))
|
||||
|
||||
choices = [f for f in os.listdir(dirname) if "png" in f]
|
||||
fname = random.choice(choices)
|
||||
|
||||
# Read Image
|
||||
img_path = os.path.join(dirname, fname)
|
||||
img = cv2.imread(img_path)
|
||||
|
||||
# Read 3D Coords
|
||||
wc_path = os.path.join(root, "wc", str(folder), fname[:-3] + "exr")
|
||||
wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
|
||||
# scale wc
|
||||
# value obtained from the entire dataset
|
||||
xmx, xmn, ymx, ymn, zmx, zmn = (
|
||||
1.2539363,
|
||||
-1.2442188,
|
||||
1.2396319,
|
||||
-1.2289206,
|
||||
0.6436657,
|
||||
-0.67492497,
|
||||
)
|
||||
wc[:, :, 0] = (wc[:, :, 0] - zmn) / (zmx - zmn)
|
||||
wc[:, :, 1] = (wc[:, :, 1] - ymn) / (ymx - ymn)
|
||||
wc[:, :, 2] = (wc[:, :, 2] - xmn) / (xmx - xmn)
|
||||
|
||||
# Read Backward Map
|
||||
bm_path = os.path.join(root, "bm", str(folder), fname[:-3] + "mat")
|
||||
bm = h5.loadmat(bm_path)["bm"]
|
||||
|
||||
# Read UV Map
|
||||
uv_path = os.path.join(root, "uv", str(folder), fname[:-3] + "exr")
|
||||
uv = cv2.imread(uv_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
|
||||
|
||||
# Read Depth Map
|
||||
dmap_path = os.path.join(root, "dmap", str(folder), fname[:-3] + "exr")
|
||||
dmap = cv2.imread(dmap_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)[:, :, 0]
|
||||
# do some clipping and scaling to display it
|
||||
dmap[dmap > 30.0] = 30
|
||||
dmap = 1 - ((dmap - np.min(dmap)) / (np.max(dmap) - np.min(dmap)))
|
||||
|
||||
# Read Normal Map
|
||||
norm_path = os.path.join(root, "norm", str(folder), fname[:-3] + "exr")
|
||||
norm = cv2.imread(norm_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
|
||||
|
||||
# Read Albedo
|
||||
alb_path = os.path.join(root, "alb", str(folder), fname[:-3] + "png")
|
||||
alb = cv2.imread(alb_path)
|
||||
|
||||
# Read Checkerboard Image
|
||||
recon_path = os.path.join(root, "recon", str(folder), fname[:-8] + "chess480001.png")
|
||||
recon = cv2.imread(recon_path)
|
||||
|
||||
# Display image and GTs
|
||||
|
||||
# use the backward mapping to dewarp the image
|
||||
# scale bm to -1.0 to 1.0
|
||||
bm_ = bm / np.array([448, 448])
|
||||
bm_ = (bm_ - 0.5) * 2
|
||||
bm_ = np.reshape(bm_, (1, 448, 448, 2))
|
||||
bm_ = paddle.to_tensor(bm_, dtype="float32")
|
||||
img_ = alb.transpose((2, 0, 1)).astype(np.float32) / 255.0
|
||||
img_ = np.expand_dims(img_, 0)
|
||||
img_ = paddle.to_tensor(img_, dtype="float32")
|
||||
uw = F.grid_sample(img_, bm_)
|
||||
uw = uw[0].numpy().transpose((1, 2, 0))
|
||||
|
||||
f, axrr = plt.subplots(2, 5)
|
||||
for ax in axrr:
|
||||
for a in ax:
|
||||
a.set_xticks([])
|
||||
a.set_yticks([])
|
||||
|
||||
axrr[0][0].imshow(img)
|
||||
axrr[0][0].title.set_text("image")
|
||||
axrr[0][1].imshow(wc)
|
||||
axrr[0][1].title.set_text("3D coords")
|
||||
axrr[0][2].imshow(bm[:, :, 0])
|
||||
axrr[0][2].title.set_text("bm 0")
|
||||
axrr[0][3].imshow(bm[:, :, 1])
|
||||
axrr[0][3].title.set_text("bm 1")
|
||||
if uv is None:
|
||||
uv = np.zeros_like(img)
|
||||
axrr[0][4].imshow(uv)
|
||||
axrr[0][4].title.set_text("uv map")
|
||||
axrr[1][0].imshow(dmap)
|
||||
axrr[1][0].title.set_text("depth map")
|
||||
axrr[1][1].imshow(norm)
|
||||
axrr[1][1].title.set_text("normal map")
|
||||
axrr[1][2].imshow(alb)
|
||||
axrr[1][2].title.set_text("albedo")
|
||||
axrr[1][3].imshow(recon)
|
||||
axrr[1][3].title.set_text("checkerboard")
|
||||
axrr[1][4].imshow(uw)
|
||||
axrr[1][4].title.set_text("gt unwarped")
|
||||
plt.tight_layout()
|
||||
plt.savefig(args.output)
|
||||
21
services/paddle_services/doc_dewarp/dewarp.py
Normal file
21
services/paddle_services/doc_dewarp/dewarp.py
Normal file
@@ -0,0 +1,21 @@
|
||||
import cv2
|
||||
import numpy as np
|
||||
import paddle
|
||||
|
||||
from . import DOC_TR
|
||||
from .utils import to_tensor, to_image
|
||||
|
||||
|
||||
def dewarp_image(image):
|
||||
img = cv2.resize(image, (288, 288)).astype(np.float32)
|
||||
y = to_tensor(image)
|
||||
|
||||
img = np.transpose(img, (2, 0, 1))
|
||||
bm = DOC_TR.run(None, {"image": img[None,]})[0]
|
||||
bm = paddle.to_tensor(bm)
|
||||
bm = paddle.nn.functional.interpolate(
|
||||
bm, y.shape[2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
bm_nhwc = np.transpose(bm, (0, 2, 3, 1))
|
||||
out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
|
||||
return to_image(out)
|
||||
161
services/paddle_services/doc_dewarp/doc/download_dataset.sh
Normal file
161
services/paddle_services/doc_dewarp/doc/download_dataset.sh
Normal file
@@ -0,0 +1,161 @@
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_1.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_2.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_3.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_4.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_5.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_6.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_7.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_8.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_9.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_10.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_11.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_12.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_13.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_14.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_15.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_16.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_17.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_18.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_19.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_20.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_21.zip"
|
||||
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_1.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_2.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_3.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_4.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_5.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_6.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_7.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_8.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_9.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_10.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_11.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_12.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_13.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_14.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_15.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_16.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_17.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_18.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_19.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_20.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_21.zip"
|
||||
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_1.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_2.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_3.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_4.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_5.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_6.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_7.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_8.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_9.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_10.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_11.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_12.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_13.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_14.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_15.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_16.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_17.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_18.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_19.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_20.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_21.zip"
|
||||
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_1.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_2.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_3.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_4.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_5.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_6.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_7.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_8.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_10.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_11.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_12.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_13.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_14.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_15.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_16.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_17.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_18.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_19.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_20.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_21.zip"
|
||||
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_1.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_2.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_3.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_4.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_5.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_6.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_7.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_8.zip"
|
||||
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_1.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_2.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_3.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_4.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_5.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_6.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_7.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_8.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_9.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_10.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_11.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_12.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_13.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_14.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_15.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_16.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_17.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_18.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_19.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_20.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_21.zip"
|
||||
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_1.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_2.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_3.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_4.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_5.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_6.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_7.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_8.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_9.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_10.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_11.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_12.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_13.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_14.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_15.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_16.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_17.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_18.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_19.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_20.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_21.zip"
|
||||
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_1.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_2.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_3.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_4.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_5.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_6.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_7.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_8.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_9.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_10.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_11.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_12.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_13.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_14.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_15.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_16.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_17.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_18.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_19.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_20.zip"
|
||||
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_21.zip"
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 76 KiB |
129
services/paddle_services/doc_dewarp/doc3d_dataset.py
Normal file
129
services/paddle_services/doc_dewarp/doc3d_dataset.py
Normal file
@@ -0,0 +1,129 @@
|
||||
import collections
|
||||
import os
|
||||
import random
|
||||
|
||||
import albumentations as A
|
||||
import cv2
|
||||
import hdf5storage as h5
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle import io
|
||||
|
||||
# Set random seed
|
||||
random.seed(12345678)
|
||||
|
||||
|
||||
class Doc3dDataset(io.Dataset):
|
||||
def __init__(self, root, split="train", is_augment=False, image_size=512):
|
||||
self.root = os.path.expanduser(root)
|
||||
|
||||
self.split = split
|
||||
self.is_augment = is_augment
|
||||
|
||||
self.files = collections.defaultdict(list)
|
||||
|
||||
self.image_size = (
|
||||
image_size if isinstance(image_size, tuple) else (image_size, image_size)
|
||||
)
|
||||
|
||||
# Augmentation
|
||||
self.augmentation = A.Compose(
|
||||
[
|
||||
A.ColorJitter(),
|
||||
]
|
||||
)
|
||||
|
||||
for split in ["train", "val"]:
|
||||
path = os.path.join(self.root, split + ".txt")
|
||||
file_list = []
|
||||
with open(path, "r") as file:
|
||||
file_list = [file_id.rstrip() for file_id in file.readlines()]
|
||||
self.files[split] = file_list
|
||||
|
||||
def __len__(self):
|
||||
return len(self.files[self.split])
|
||||
|
||||
def __getitem__(self, index):
|
||||
image_name = self.files[self.split][index]
|
||||
|
||||
# Read image
|
||||
image_path = os.path.join(self.root, "img", image_name + ".png")
|
||||
image = cv2.imread(image_path)
|
||||
|
||||
# Read 3D Coordinates
|
||||
wc_path = os.path.join(self.root, "wc", image_name + ".exr")
|
||||
wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
|
||||
|
||||
# Read backward map
|
||||
bm_path = os.path.join(self.root, "bm", image_name + ".mat")
|
||||
bm = h5.loadmat(bm_path)["bm"]
|
||||
|
||||
image, bm = self.transform(wc, bm, image)
|
||||
|
||||
return image, bm
|
||||
|
||||
def tight_crop(self, wc: np.ndarray):
|
||||
mask = ((wc[:, :, 0] != 0) & (wc[:, :, 1] != 0) & (wc[:, :, 2] != 0)).astype(
|
||||
np.uint8
|
||||
)
|
||||
mask_size = mask.shape
|
||||
[y, x] = mask.nonzero()
|
||||
min_x = min(x)
|
||||
max_x = max(x)
|
||||
min_y = min(y)
|
||||
max_y = max(y)
|
||||
|
||||
wc = wc[min_y : max_y + 1, min_x : max_x + 1, :]
|
||||
s = 10
|
||||
wc = np.pad(wc, ((s, s), (s, s), (0, 0)), "constant")
|
||||
|
||||
cx1 = random.randint(0, 2 * s)
|
||||
cx2 = random.randint(0, 2 * s) + 1
|
||||
cy1 = random.randint(0, 2 * s)
|
||||
cy2 = random.randint(0, 2 * s) + 1
|
||||
|
||||
wc = wc[cy1:-cy2, cx1:-cx2, :]
|
||||
|
||||
top: int = min_y - s + cy1
|
||||
bottom: int = mask_size[0] - max_y - s + cy2
|
||||
left: int = min_x - s + cx1
|
||||
right: int = mask_size[1] - max_x - s + cx2
|
||||
|
||||
top = np.clip(top, 0, mask_size[0])
|
||||
bottom = np.clip(bottom, 1, mask_size[0] - 1)
|
||||
left = np.clip(left, 0, mask_size[1])
|
||||
right = np.clip(right, 1, mask_size[1] - 1)
|
||||
|
||||
return wc, top, bottom, left, right
|
||||
|
||||
def transform(self, wc, bm, img):
|
||||
wc, top, bottom, left, right = self.tight_crop(wc)
|
||||
|
||||
img = img[top:-bottom, left:-right, :]
|
||||
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
if self.is_augment:
|
||||
img = self.augmentation(image=img)["image"]
|
||||
|
||||
# resize image
|
||||
img = cv2.resize(img, self.image_size)
|
||||
img = img.astype(np.float32) / 255.0
|
||||
img = img.transpose(2, 0, 1)
|
||||
|
||||
# resize bm
|
||||
bm = bm.astype(np.float32)
|
||||
bm[:, :, 1] = bm[:, :, 1] - top
|
||||
bm[:, :, 0] = bm[:, :, 0] - left
|
||||
bm = bm / np.array([448.0 - left - right, 448.0 - top - bottom])
|
||||
bm0 = cv2.resize(bm[:, :, 0], (self.image_size[0], self.image_size[1]))
|
||||
bm1 = cv2.resize(bm[:, :, 1], (self.image_size[0], self.image_size[1]))
|
||||
bm0 = bm0 * self.image_size[0]
|
||||
bm1 = bm1 * self.image_size[1]
|
||||
|
||||
bm = np.stack([bm0, bm1], axis=-1)
|
||||
|
||||
img = paddle.to_tensor(img).astype(dtype="float32")
|
||||
bm = paddle.to_tensor(bm).astype(dtype="float32")
|
||||
|
||||
return img, bm
|
||||
66
services/paddle_services/doc_dewarp/export.py
Normal file
66
services/paddle_services/doc_dewarp/export.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import paddle
|
||||
|
||||
from GeoTr import GeoTr
|
||||
|
||||
|
||||
def export(args):
|
||||
model_path = args.model
|
||||
imgsz = args.imgsz
|
||||
format = args.format
|
||||
|
||||
model = GeoTr()
|
||||
checkpoint = paddle.load(model_path)
|
||||
model.set_state_dict(checkpoint["model"])
|
||||
model.eval()
|
||||
|
||||
dirname = os.path.dirname(model_path)
|
||||
if format == "static" or format == "onnx":
|
||||
model = paddle.jit.to_static(
|
||||
model,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec(shape=[1, 3, imgsz, imgsz], dtype="float32")
|
||||
],
|
||||
full_graph=True,
|
||||
)
|
||||
paddle.jit.save(model, os.path.join(dirname, "model"))
|
||||
|
||||
if format == "onnx":
|
||||
onnx_path = os.path.join(dirname, "model.onnx")
|
||||
os.system(
|
||||
f"paddle2onnx --model_dir {dirname}"
|
||||
" --model_filename model.pdmodel"
|
||||
" --params_filename model.pdiparams"
|
||||
f" --save_file {onnx_path}"
|
||||
" --opset_version 11"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="export model")
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"-m",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="",
|
||||
help="The path of model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--imgsz", type=int, default=288, help="The size of input image"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--format",
|
||||
type=str,
|
||||
default="static",
|
||||
help="The format of exported model, which can be static or onnx",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
export(args)
|
||||
110
services/paddle_services/doc_dewarp/extractor.py
Normal file
110
services/paddle_services/doc_dewarp/extractor.py
Normal file
@@ -0,0 +1,110 @@
|
||||
import paddle.nn as nn
|
||||
|
||||
from .weight_init import weight_init_
|
||||
|
||||
|
||||
class ResidualBlock(nn.Layer):
|
||||
"""Residual Block with custom normalization."""
|
||||
|
||||
def __init__(self, in_planes, planes, norm_fn="group", stride=1):
|
||||
super(ResidualBlock, self).__init__()
|
||||
|
||||
self.conv1 = nn.Conv2D(in_planes, planes, 3, padding=1, stride=stride)
|
||||
self.conv2 = nn.Conv2D(planes, planes, 3, padding=1)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
if norm_fn == "group":
|
||||
num_groups = planes // 8
|
||||
self.norm1 = nn.GroupNorm(num_groups, planes)
|
||||
self.norm2 = nn.GroupNorm(num_groups, planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.GroupNorm(num_groups, planes)
|
||||
elif norm_fn == "batch":
|
||||
self.norm1 = nn.BatchNorm2D(planes)
|
||||
self.norm2 = nn.BatchNorm2D(planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.BatchNorm2D(planes)
|
||||
elif norm_fn == "instance":
|
||||
self.norm1 = nn.InstanceNorm2D(planes)
|
||||
self.norm2 = nn.InstanceNorm2D(planes)
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.InstanceNorm2D(planes)
|
||||
elif norm_fn == "none":
|
||||
self.norm1 = nn.Sequential()
|
||||
self.norm2 = nn.Sequential()
|
||||
if not stride == 1:
|
||||
self.norm3 = nn.Sequential()
|
||||
|
||||
if stride == 1:
|
||||
self.downsample = None
|
||||
else:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2D(in_planes, planes, 1, stride=stride), self.norm3
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
y = x
|
||||
y = self.relu(self.norm1(self.conv1(y)))
|
||||
y = self.relu(self.norm2(self.conv2(y)))
|
||||
|
||||
if self.downsample is not None:
|
||||
x = self.downsample(x)
|
||||
|
||||
return self.relu(x + y)
|
||||
|
||||
|
||||
class BasicEncoder(nn.Layer):
|
||||
"""Basic Encoder with custom normalization."""
|
||||
|
||||
def __init__(self, output_dim=128, norm_fn="batch"):
|
||||
super(BasicEncoder, self).__init__()
|
||||
|
||||
self.norm_fn = norm_fn
|
||||
|
||||
if self.norm_fn == "group":
|
||||
self.norm1 = nn.GroupNorm(8, 64)
|
||||
elif self.norm_fn == "batch":
|
||||
self.norm1 = nn.BatchNorm2D(64)
|
||||
elif self.norm_fn == "instance":
|
||||
self.norm1 = nn.InstanceNorm2D(64)
|
||||
elif self.norm_fn == "none":
|
||||
self.norm1 = nn.Sequential()
|
||||
|
||||
self.conv1 = nn.Conv2D(3, 64, 7, stride=2, padding=3)
|
||||
self.relu1 = nn.ReLU()
|
||||
|
||||
self.in_planes = 64
|
||||
self.layer1 = self._make_layer(64, stride=1)
|
||||
self.layer2 = self._make_layer(128, stride=2)
|
||||
self.layer3 = self._make_layer(192, stride=2)
|
||||
|
||||
self.conv2 = nn.Conv2D(192, output_dim, 1)
|
||||
|
||||
for m in self.sublayers():
|
||||
if isinstance(m, nn.Conv2D):
|
||||
weight_init_(
|
||||
m.weight, "kaiming_normal_", mode="fan_out", nonlinearity="relu"
|
||||
)
|
||||
elif isinstance(m, (nn.BatchNorm2D, nn.InstanceNorm2D, nn.GroupNorm)):
|
||||
weight_init_(m, "Constant", value=1, bias_value=0.0)
|
||||
|
||||
def _make_layer(self, dim, stride=1):
|
||||
layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
|
||||
layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
|
||||
layers = layer1, layer2
|
||||
|
||||
self.in_planes = dim
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.norm1(x)
|
||||
x = self.relu1(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
|
||||
x = self.conv2(x)
|
||||
|
||||
return x
|
||||
48
services/paddle_services/doc_dewarp/plots.py
Normal file
48
services/paddle_services/doc_dewarp/plots.py
Normal file
@@ -0,0 +1,48 @@
|
||||
import copy
|
||||
import os
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import paddle.optimizer as optim
|
||||
|
||||
from GeoTr import GeoTr
|
||||
|
||||
|
||||
def plot_lr_scheduler(optimizer, scheduler, epochs=65, save_dir=""):
|
||||
"""
|
||||
Plot the learning rate scheduler
|
||||
"""
|
||||
|
||||
optimizer = copy.copy(optimizer)
|
||||
scheduler = copy.copy(scheduler)
|
||||
|
||||
lr = []
|
||||
for _ in range(epochs):
|
||||
for _ in range(30):
|
||||
lr.append(scheduler.get_lr())
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
|
||||
epoch = [float(i) / 30.0 for i in range(len(lr))]
|
||||
|
||||
plt.figure()
|
||||
plt.plot(epoch, lr, ".-", label="Learning Rate")
|
||||
plt.xlabel("epoch")
|
||||
plt.ylabel("Learning Rate")
|
||||
plt.title("Learning Rate Scheduler")
|
||||
plt.savefig(os.path.join(save_dir, "lr_scheduler.png"), dpi=300)
|
||||
plt.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
model = GeoTr()
|
||||
|
||||
schaduler = optim.lr.OneCycleLR(
|
||||
max_learning_rate=1e-4,
|
||||
total_steps=1950,
|
||||
phase_pct=0.1,
|
||||
end_learning_rate=1e-4 / 2.5e5,
|
||||
)
|
||||
optimizer = optim.AdamW(learning_rate=schaduler, parameters=model.parameters())
|
||||
plot_lr_scheduler(
|
||||
scheduler=schaduler, optimizer=optimizer, epochs=65, save_dir="./"
|
||||
)
|
||||
124
services/paddle_services/doc_dewarp/position_encoding.py
Normal file
124
services/paddle_services/doc_dewarp/position_encoding.py
Normal file
@@ -0,0 +1,124 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.nn.initializer as init
|
||||
|
||||
|
||||
class NestedTensor(object):
|
||||
def __init__(self, tensors, mask: Optional[paddle.Tensor]):
|
||||
self.tensors = tensors
|
||||
self.mask = mask
|
||||
|
||||
def decompose(self):
|
||||
return self.tensors, self.mask
|
||||
|
||||
def __repr__(self):
|
||||
return str(self.tensors)
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Layer):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
|
||||
):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, mask):
|
||||
assert mask is not None
|
||||
|
||||
y_embed = mask.cumsum(axis=1, dtype="float32")
|
||||
x_embed = mask.cumsum(axis=2, dtype="float32")
|
||||
|
||||
if self.normalize:
|
||||
eps = 1e-06
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = paddle.arange(end=self.num_pos_feats, dtype="float32")
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
|
||||
pos_x = paddle.stack(
|
||||
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), axis=4
|
||||
).flatten(3)
|
||||
pos_y = paddle.stack(
|
||||
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), axis=4
|
||||
).flatten(3)
|
||||
|
||||
pos = paddle.concat((pos_y, pos_x), axis=3).transpose(perm=[0, 3, 1, 2])
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
class PositionEmbeddingLearned(nn.Layer):
|
||||
"""
|
||||
Absolute pos embedding, learned.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=256):
|
||||
super().__init__()
|
||||
self.row_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.col_embed = nn.Embedding(50, num_pos_feats)
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
init_Constant = init.Uniform()
|
||||
init_Constant(self.row_embed.weight)
|
||||
init_Constant(self.col_embed.weight)
|
||||
|
||||
def forward(self, tensor_list: NestedTensor):
|
||||
x = tensor_list.tensors
|
||||
|
||||
h, w = x.shape[-2:]
|
||||
|
||||
i = paddle.arange(end=w)
|
||||
j = paddle.arange(end=h)
|
||||
|
||||
x_emb = self.col_embed(i)
|
||||
y_emb = self.row_embed(j)
|
||||
|
||||
pos = (
|
||||
paddle.concat(
|
||||
[
|
||||
x_emb.unsqueeze(0).tile([h, 1, 1]),
|
||||
y_emb.unsqueeze(1).tile([1, w, 1]),
|
||||
],
|
||||
axis=-1,
|
||||
)
|
||||
.transpose([2, 0, 1])
|
||||
.unsqueeze(0)
|
||||
.tile([x.shape[0], 1, 1, 1])
|
||||
)
|
||||
|
||||
return pos
|
||||
|
||||
|
||||
def build_position_encoding(hidden_dim=512, position_embedding="sine"):
|
||||
N_steps = hidden_dim // 2
|
||||
|
||||
if position_embedding in ("v2", "sine"):
|
||||
position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
|
||||
elif position_embedding in ("v3", "learned"):
|
||||
position_embedding = PositionEmbeddingLearned(N_steps)
|
||||
else:
|
||||
raise ValueError(f"not supported {position_embedding}")
|
||||
return position_embedding
|
||||
69
services/paddle_services/doc_dewarp/predict.py
Normal file
69
services/paddle_services/doc_dewarp/predict.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import argparse
|
||||
|
||||
import cv2
|
||||
import paddle
|
||||
|
||||
from GeoTr import GeoTr
|
||||
from utils import to_image, to_tensor
|
||||
|
||||
|
||||
def run(args):
|
||||
image_path = args.image
|
||||
model_path = args.model
|
||||
output_path = args.output
|
||||
|
||||
checkpoint = paddle.load(model_path)
|
||||
state_dict = checkpoint["model"]
|
||||
model = GeoTr()
|
||||
model.set_state_dict(state_dict)
|
||||
model.eval()
|
||||
|
||||
img_org = cv2.imread(image_path)
|
||||
img = cv2.resize(img_org, (288, 288))
|
||||
x = to_tensor(img)
|
||||
y = to_tensor(img_org)
|
||||
bm = model(x)
|
||||
bm = paddle.nn.functional.interpolate(
|
||||
bm, y.shape[2:], mode="bilinear", align_corners=False
|
||||
)
|
||||
bm_nhwc = bm.transpose([0, 2, 3, 1])
|
||||
out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
|
||||
out_image = to_image(out)
|
||||
cv2.imwrite(output_path, out_image)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="predict")
|
||||
|
||||
parser.add_argument(
|
||||
"--image",
|
||||
"-i",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="",
|
||||
help="The path of image",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
"-m",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="",
|
||||
help="The path of model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
"-o",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="",
|
||||
help="The path of output",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(args)
|
||||
|
||||
run(args)
|
||||
7
services/paddle_services/doc_dewarp/requirements.txt
Normal file
7
services/paddle_services/doc_dewarp/requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
hdf5storage
|
||||
loguru
|
||||
numpy
|
||||
scipy
|
||||
opencv-python
|
||||
matplotlib
|
||||
albumentations
|
||||
57
services/paddle_services/doc_dewarp/split_dataset.py
Normal file
57
services/paddle_services/doc_dewarp/split_dataset.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import argparse
|
||||
import glob
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
from loguru import logger
|
||||
|
||||
random.seed(1234567)
|
||||
|
||||
|
||||
def run(args):
|
||||
data_root = os.path.expanduser(args.data_root)
|
||||
ratio = args.train_ratio
|
||||
|
||||
data_path = os.path.join(data_root, "img", "*", "*.png")
|
||||
img_list = glob.glob(data_path, recursive=True)
|
||||
sorted(img_list)
|
||||
random.shuffle(img_list)
|
||||
|
||||
train_size = int(len(img_list) * ratio)
|
||||
|
||||
train_text_path = os.path.join(data_root, "train.txt")
|
||||
with open(train_text_path, "w") as file:
|
||||
for item in img_list[:train_size]:
|
||||
parts = Path(item).parts
|
||||
item = os.path.join(parts[-2], parts[-1])
|
||||
file.write("%s\n" % item.split(".png")[0])
|
||||
|
||||
val_text_path = os.path.join(data_root, "val.txt")
|
||||
with open(val_text_path, "w") as file:
|
||||
for item in img_list[train_size:]:
|
||||
parts = Path(item).parts
|
||||
item = os.path.join(parts[-2], parts[-1])
|
||||
file.write("%s\n" % item.split(".png")[0])
|
||||
|
||||
logger.info(f"TRAIN LABEL: {train_text_path}")
|
||||
logger.info(f"VAL LABEL: {val_text_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--data_root",
|
||||
type=str,
|
||||
default="~/datasets/doc3d",
|
||||
help="Data path to load data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_ratio", type=float, default=0.8, help="Ratio of training data"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info(args)
|
||||
|
||||
run(args)
|
||||
381
services/paddle_services/doc_dewarp/train.py
Normal file
381
services/paddle_services/doc_dewarp/train.py
Normal file
@@ -0,0 +1,381 @@
|
||||
import argparse
|
||||
import inspect
|
||||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
import paddle.optimizer as optim
|
||||
from loguru import logger
|
||||
from paddle.io import DataLoader
|
||||
from paddle.nn import functional as F
|
||||
from paddle_msssim import ms_ssim, ssim
|
||||
|
||||
from doc3d_dataset import Doc3dDataset
|
||||
from GeoTr import GeoTr
|
||||
from utils import to_image
|
||||
|
||||
FILE = Path(__file__).resolve()
|
||||
ROOT = FILE.parents[0]
|
||||
RANK = int(os.getenv("RANK", -1))
|
||||
|
||||
|
||||
def init_seeds(seed=0, deterministic=False):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
paddle.seed(seed)
|
||||
if deterministic:
|
||||
os.environ["FLAGS_cudnn_deterministic"] = "1"
|
||||
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
||||
os.environ["PYTHONHASHSEED"] = str(seed)
|
||||
|
||||
|
||||
def colorstr(*input):
|
||||
# Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code,
|
||||
# i.e. colorstr('blue', 'hello world')
|
||||
|
||||
*args, string = (
|
||||
input if len(input) > 1 else ("blue", "bold", input[0])
|
||||
) # color arguments, string
|
||||
|
||||
colors = {
|
||||
"black": "\033[30m", # basic colors
|
||||
"red": "\033[31m",
|
||||
"green": "\033[32m",
|
||||
"yellow": "\033[33m",
|
||||
"blue": "\033[34m",
|
||||
"magenta": "\033[35m",
|
||||
"cyan": "\033[36m",
|
||||
"white": "\033[37m",
|
||||
"bright_black": "\033[90m", # bright colors
|
||||
"bright_red": "\033[91m",
|
||||
"bright_green": "\033[92m",
|
||||
"bright_yellow": "\033[93m",
|
||||
"bright_blue": "\033[94m",
|
||||
"bright_magenta": "\033[95m",
|
||||
"bright_cyan": "\033[96m",
|
||||
"bright_white": "\033[97m",
|
||||
"end": "\033[0m", # misc
|
||||
"bold": "\033[1m",
|
||||
"underline": "\033[4m",
|
||||
}
|
||||
|
||||
return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
|
||||
|
||||
|
||||
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
|
||||
# Print function arguments (optional args dict)
|
||||
x = inspect.currentframe().f_back # previous frame
|
||||
file, _, func, _, _ = inspect.getframeinfo(x)
|
||||
if args is None: # get args automatically
|
||||
args, _, _, frm = inspect.getargvalues(x)
|
||||
args = {k: v for k, v in frm.items() if k in args}
|
||||
try:
|
||||
file = Path(file).resolve().relative_to(ROOT).with_suffix("")
|
||||
except ValueError:
|
||||
file = Path(file).stem
|
||||
s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
|
||||
logger.info(colorstr(s) + ", ".join(f"{k}={v}" for k, v in args.items()))
|
||||
|
||||
|
||||
def increment_path(path, exist_ok=False, sep="", mkdir=False):
|
||||
# Increment file or directory path,
|
||||
# i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
|
||||
path = Path(path) # os-agnostic
|
||||
if path.exists() and not exist_ok:
|
||||
path, suffix = (
|
||||
(path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
|
||||
)
|
||||
|
||||
for n in range(2, 9999):
|
||||
p = f"{path}{sep}{n}{suffix}" # increment path
|
||||
if not os.path.exists(p):
|
||||
break
|
||||
path = Path(p)
|
||||
|
||||
if mkdir:
|
||||
path.mkdir(parents=True, exist_ok=True) # make directory
|
||||
|
||||
return path
|
||||
|
||||
|
||||
def train(args):
|
||||
save_dir = Path(args.save_dir)
|
||||
|
||||
use_vdl = args.use_vdl
|
||||
if use_vdl:
|
||||
from visualdl import LogWriter
|
||||
|
||||
log_dir = save_dir / "vdl"
|
||||
vdl_writer = LogWriter(str(log_dir))
|
||||
|
||||
# Directories
|
||||
weights_dir = save_dir / "weights"
|
||||
weights_dir.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
last = weights_dir / "last.ckpt"
|
||||
best = weights_dir / "best.ckpt"
|
||||
|
||||
# Hyperparameters
|
||||
|
||||
# Config
|
||||
init_seeds(args.seed)
|
||||
|
||||
# Train loader
|
||||
train_dataset = Doc3dDataset(
|
||||
args.data_root,
|
||||
split="train",
|
||||
is_augment=True,
|
||||
image_size=args.img_size,
|
||||
)
|
||||
|
||||
train_loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_size=args.batch_size,
|
||||
shuffle=True,
|
||||
num_workers=args.workers,
|
||||
)
|
||||
|
||||
# Validation loader
|
||||
val_dataset = Doc3dDataset(
|
||||
args.data_root,
|
||||
split="val",
|
||||
is_augment=False,
|
||||
image_size=args.img_size,
|
||||
)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset, batch_size=args.batch_size, num_workers=args.workers
|
||||
)
|
||||
|
||||
# Model
|
||||
model = GeoTr()
|
||||
|
||||
if use_vdl:
|
||||
vdl_writer.add_graph(
|
||||
model,
|
||||
input_spec=[
|
||||
paddle.static.InputSpec([1, 3, args.img_size, args.img_size], "float32")
|
||||
],
|
||||
)
|
||||
|
||||
# Data Parallel Mode
|
||||
if RANK == -1 and paddle.device.cuda.device_count() > 1:
|
||||
model = paddle.DataParallel(model)
|
||||
|
||||
# Scheduler
|
||||
scheduler = optim.lr.OneCycleLR(
|
||||
max_learning_rate=args.lr,
|
||||
total_steps=args.epochs * len(train_loader),
|
||||
phase_pct=0.1,
|
||||
end_learning_rate=args.lr / 2.5e5,
|
||||
)
|
||||
|
||||
# Optimizer
|
||||
optimizer = optim.AdamW(
|
||||
learning_rate=scheduler,
|
||||
parameters=model.parameters(),
|
||||
)
|
||||
|
||||
# loss function
|
||||
l1_loss_fn = nn.L1Loss()
|
||||
mse_loss_fn = nn.MSELoss()
|
||||
|
||||
# Resume
|
||||
best_fitness, start_epoch = 0.0, 0
|
||||
if args.resume:
|
||||
ckpt = paddle.load(args.resume)
|
||||
model.set_state_dict(ckpt["model"])
|
||||
optimizer.set_state_dict(ckpt["optimizer"])
|
||||
scheduler.set_state_dict(ckpt["scheduler"])
|
||||
best_fitness = ckpt["best_fitness"]
|
||||
start_epoch = ckpt["epoch"] + 1
|
||||
|
||||
# Train
|
||||
for epoch in range(start_epoch, args.epochs):
|
||||
model.train()
|
||||
|
||||
for i, (img, target) in enumerate(train_loader):
|
||||
img = paddle.to_tensor(img) # NCHW
|
||||
target = paddle.to_tensor(target) # NHWC
|
||||
|
||||
pred = model(img) # NCHW
|
||||
pred_nhwc = pred.transpose([0, 2, 3, 1])
|
||||
|
||||
loss = l1_loss_fn(pred_nhwc, target)
|
||||
mse_loss = mse_loss_fn(pred_nhwc, target)
|
||||
|
||||
if use_vdl:
|
||||
vdl_writer.add_scalar(
|
||||
"Train/L1 Loss", float(loss), epoch * len(train_loader) + i
|
||||
)
|
||||
vdl_writer.add_scalar(
|
||||
"Train/MSE Loss", float(mse_loss), epoch * len(train_loader) + i
|
||||
)
|
||||
vdl_writer.add_scalar(
|
||||
"Train/Learning Rate",
|
||||
float(scheduler.get_lr()),
|
||||
epoch * len(train_loader) + i,
|
||||
)
|
||||
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
scheduler.step()
|
||||
optimizer.clear_grad()
|
||||
|
||||
if i % 10 == 0:
|
||||
logger.info(
|
||||
f"[TRAIN MODE] Epoch: {epoch}, Iter: {i}, L1 Loss: {float(loss)}, "
|
||||
f"MSE Loss: {float(mse_loss)}, LR: {float(scheduler.get_lr())}"
|
||||
)
|
||||
|
||||
# Validation
|
||||
model.eval()
|
||||
|
||||
with paddle.no_grad():
|
||||
avg_ssim = paddle.zeros([])
|
||||
avg_ms_ssim = paddle.zeros([])
|
||||
avg_l1_loss = paddle.zeros([])
|
||||
avg_mse_loss = paddle.zeros([])
|
||||
|
||||
for i, (img, target) in enumerate(val_loader):
|
||||
img = paddle.to_tensor(img)
|
||||
target = paddle.to_tensor(target)
|
||||
|
||||
pred = model(img)
|
||||
pred_nhwc = pred.transpose([0, 2, 3, 1])
|
||||
|
||||
# predict image
|
||||
out = F.grid_sample(img, (pred_nhwc / args.img_size - 0.5) * 2)
|
||||
out_gt = F.grid_sample(img, (target / args.img_size - 0.5) * 2)
|
||||
|
||||
# calculate ssim
|
||||
ssim_val = ssim(out, out_gt, data_range=1.0)
|
||||
ms_ssim_val = ms_ssim(out, out_gt, data_range=1.0)
|
||||
|
||||
loss = l1_loss_fn(pred_nhwc, target)
|
||||
mse_loss = mse_loss_fn(pred_nhwc, target)
|
||||
|
||||
# calculate fitness
|
||||
avg_ssim += ssim_val
|
||||
avg_ms_ssim += ms_ssim_val
|
||||
avg_l1_loss += loss
|
||||
avg_mse_loss += mse_loss
|
||||
|
||||
if i % 10 == 0:
|
||||
logger.info(
|
||||
f"[VAL MODE] Epoch: {epoch}, VAL Iter: {i}, "
|
||||
f"L1 Loss: {float(loss)} MSE Loss: {float(mse_loss)}, "
|
||||
f"MS-SSIM: {float(ms_ssim_val)}, SSIM: {float(ssim_val)}"
|
||||
)
|
||||
|
||||
if use_vdl and i == 0:
|
||||
img_0 = to_image(out[0])
|
||||
img_gt_0 = to_image(out_gt[0])
|
||||
vdl_writer.add_image("Val/Predicted Image No.0", img_0, epoch)
|
||||
vdl_writer.add_image("Val/Target Image No.0", img_gt_0, epoch)
|
||||
|
||||
img_1 = to_image(out[1])
|
||||
img_gt_1 = to_image(out_gt[1])
|
||||
img_gt_1 = img_gt_1.astype("uint8")
|
||||
vdl_writer.add_image("Val/Predicted Image No.1", img_1, epoch)
|
||||
vdl_writer.add_image("Val/Target Image No.1", img_gt_1, epoch)
|
||||
|
||||
img_2 = to_image(out[2])
|
||||
img_gt_2 = to_image(out_gt[2])
|
||||
vdl_writer.add_image("Val/Predicted Image No.2", img_2, epoch)
|
||||
vdl_writer.add_image("Val/Target Image No.2", img_gt_2, epoch)
|
||||
|
||||
avg_ssim /= len(val_loader)
|
||||
avg_ms_ssim /= len(val_loader)
|
||||
avg_l1_loss /= len(val_loader)
|
||||
avg_mse_loss /= len(val_loader)
|
||||
|
||||
if use_vdl:
|
||||
vdl_writer.add_scalar("Val/L1 Loss", float(loss), epoch)
|
||||
vdl_writer.add_scalar("Val/MSE Loss", float(mse_loss), epoch)
|
||||
vdl_writer.add_scalar("Val/SSIM", float(ssim_val), epoch)
|
||||
vdl_writer.add_scalar("Val/MS-SSIM", float(ms_ssim_val), epoch)
|
||||
|
||||
# Save
|
||||
ckpt = {
|
||||
"model": model.state_dict(),
|
||||
"optimizer": optimizer.state_dict(),
|
||||
"scheduler": scheduler.state_dict(),
|
||||
"best_fitness": best_fitness,
|
||||
"epoch": epoch,
|
||||
}
|
||||
|
||||
paddle.save(ckpt, str(last))
|
||||
|
||||
if best_fitness < avg_ssim:
|
||||
best_fitness = avg_ssim
|
||||
paddle.save(ckpt, str(best))
|
||||
|
||||
if use_vdl:
|
||||
vdl_writer.close()
|
||||
|
||||
|
||||
def main(args):
|
||||
print_args(vars(args))
|
||||
|
||||
args.save_dir = str(
|
||||
increment_path(Path(args.project) / args.name, exist_ok=args.exist_ok)
|
||||
)
|
||||
|
||||
train(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Hyperparams")
|
||||
parser.add_argument(
|
||||
"--data-root",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default="~/datasets/doc3d",
|
||||
help="The root path of the dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--img-size",
|
||||
nargs="?",
|
||||
type=int,
|
||||
default=288,
|
||||
help="The size of the input image",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--epochs",
|
||||
nargs="?",
|
||||
type=int,
|
||||
default=65,
|
||||
help="The number of training epochs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size", nargs="?", type=int, default=12, help="Batch Size"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr", nargs="?", type=float, default=1e-04, help="Learning Rate"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume",
|
||||
nargs="?",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to previous saved model to restart from",
|
||||
)
|
||||
parser.add_argument("--workers", type=int, default=8, help="max dataloader workers")
|
||||
parser.add_argument(
|
||||
"--project", default=ROOT / "runs/train", help="save to project/name"
|
||||
)
|
||||
parser.add_argument("--name", default="exp", help="save to project/name")
|
||||
parser.add_argument(
|
||||
"--exist-ok",
|
||||
action="store_true",
|
||||
help="existing project/name ok, do not increment",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="Global training seed")
|
||||
parser.add_argument("--use-vdl", action="store_true", help="use VisualDL as logger")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
||||
10
services/paddle_services/doc_dewarp/train.sh
Normal file
10
services/paddle_services/doc_dewarp/train.sh
Normal file
@@ -0,0 +1,10 @@
|
||||
export OPENCV_IO_ENABLE_OPENEXR=1
|
||||
export FLAGS_logtostderr=0
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
|
||||
python train.py --img-size 288 \
|
||||
--name "DocTr++" \
|
||||
--batch-size 12 \
|
||||
--lr 1e-4 \
|
||||
--exist-ok \
|
||||
--use-vdl
|
||||
67
services/paddle_services/doc_dewarp/utils.py
Normal file
67
services/paddle_services/doc_dewarp/utils.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import numpy as np
|
||||
import paddle
|
||||
from paddle.nn import functional as F
|
||||
|
||||
|
||||
def to_tensor(img: np.ndarray):
|
||||
"""
|
||||
Converts a numpy array image (HWC) to a Paddle tensor (NCHW).
|
||||
|
||||
Args:
|
||||
img (numpy.ndarray): The input image as a numpy array.
|
||||
|
||||
Returns:
|
||||
out (paddle.Tensor): The output tensor.
|
||||
"""
|
||||
img = img[:, :, ::-1]
|
||||
img = img.astype("float32") / 255.0
|
||||
img = img.transpose(2, 0, 1)
|
||||
out: paddle.Tensor = paddle.to_tensor(img)
|
||||
out = paddle.unsqueeze(out, axis=0)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def to_image(x: paddle.Tensor):
|
||||
"""
|
||||
Converts a Paddle tensor (NCHW) to a numpy array image (HWC).
|
||||
|
||||
Args:
|
||||
x (paddle.Tensor): The input tensor.
|
||||
|
||||
Returns:
|
||||
out (numpy.ndarray): The output image as a numpy array.
|
||||
"""
|
||||
out: np.ndarray = x.squeeze().numpy()
|
||||
out = out.transpose(1, 2, 0)
|
||||
out = out * 255.0
|
||||
out = out.astype("uint8")
|
||||
out = out[:, :, ::-1]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def unwarp(img, bm, bm_data_format="NCHW"):
|
||||
"""
|
||||
Unwarp an image using a flow field.
|
||||
|
||||
Args:
|
||||
img (paddle.Tensor): The input image.
|
||||
bm (paddle.Tensor): The flow field.
|
||||
|
||||
Returns:
|
||||
out (paddle.Tensor): The output image.
|
||||
"""
|
||||
_, _, h, w = img.shape
|
||||
|
||||
if bm_data_format == "NHWC":
|
||||
bm = bm.transpose([0, 3, 1, 2])
|
||||
|
||||
# NCHW
|
||||
bm = F.upsample(bm, size=(h, w), mode="bilinear", align_corners=True)
|
||||
# NHWC
|
||||
bm = bm.transpose([0, 2, 3, 1])
|
||||
# NCHW
|
||||
out = F.grid_sample(img, bm)
|
||||
|
||||
return out
|
||||
152
services/paddle_services/doc_dewarp/weight_init.py
Normal file
152
services/paddle_services/doc_dewarp/weight_init.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import paddle
|
||||
import paddle.nn.initializer as init
|
||||
from scipy import special
|
||||
|
||||
|
||||
def weight_init_(
|
||||
layer, func, weight_name=None, bias_name=None, bias_value=0.0, **kwargs
|
||||
):
|
||||
"""
|
||||
In-place params init function.
|
||||
Usage:
|
||||
.. code-block:: python
|
||||
|
||||
import paddle
|
||||
import numpy as np
|
||||
|
||||
data = np.ones([3, 4], dtype='float32')
|
||||
linear = paddle.nn.Linear(4, 4)
|
||||
input = paddle.to_tensor(data)
|
||||
print(linear.weight)
|
||||
linear(input)
|
||||
|
||||
weight_init_(linear, 'Normal', 'fc_w0', 'fc_b0', std=0.01, mean=0.1)
|
||||
print(linear.weight)
|
||||
"""
|
||||
|
||||
if hasattr(layer, "weight") and layer.weight is not None:
|
||||
getattr(init, func)(**kwargs)(layer.weight)
|
||||
if weight_name is not None:
|
||||
# override weight name
|
||||
layer.weight.name = weight_name
|
||||
|
||||
if hasattr(layer, "bias") and layer.bias is not None:
|
||||
init.Constant(bias_value)(layer.bias)
|
||||
if bias_name is not None:
|
||||
# override bias name
|
||||
layer.bias.name = bias_name
|
||||
|
||||
|
||||
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
|
||||
def norm_cdf(x):
|
||||
# Computes standard normal cumulative distribution function
|
||||
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
|
||||
|
||||
if (mean < a - 2 * std) or (mean > b + 2 * std):
|
||||
print(
|
||||
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
|
||||
"The distribution of values may be incorrect."
|
||||
)
|
||||
|
||||
with paddle.no_grad():
|
||||
# Values are generated by using a truncated uniform distribution and
|
||||
# then using the inverse CDF for the normal distribution.
|
||||
# Get upper and lower cdf values
|
||||
lower = norm_cdf((a - mean) / std)
|
||||
upper = norm_cdf((b - mean) / std)
|
||||
|
||||
# Uniformly fill tensor with values from [l, u], then translate to [2l-1, 2u-1].
|
||||
tmp = np.random.uniform(
|
||||
2 * lower - 1, 2 * upper - 1, size=list(tensor.shape)
|
||||
).astype(np.float32)
|
||||
|
||||
# Use inverse cdf transform for normal distribution to get truncated
|
||||
# standard normal
|
||||
tmp = special.erfinv(tmp)
|
||||
|
||||
# Transform to proper mean, std
|
||||
tmp *= std * math.sqrt(2.0)
|
||||
tmp += mean
|
||||
|
||||
# Clamp to ensure it's in the proper range
|
||||
tmp = np.clip(tmp, a, b)
|
||||
tensor.set_value(paddle.to_tensor(tmp))
|
||||
|
||||
return tensor
|
||||
|
||||
|
||||
def _calculate_fan_in_and_fan_out(tensor):
|
||||
dimensions = tensor.dim()
|
||||
if dimensions < 2:
|
||||
raise ValueError(
|
||||
"Fan in and fan out can not be computed for tensor "
|
||||
"with fewer than 2 dimensions"
|
||||
)
|
||||
|
||||
num_input_fmaps = tensor.shape[1]
|
||||
num_output_fmaps = tensor.shape[0]
|
||||
receptive_field_size = 1
|
||||
if tensor.dim() > 2:
|
||||
receptive_field_size = tensor[0][0].numel()
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
|
||||
return fan_in, fan_out
|
||||
|
||||
|
||||
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
|
||||
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
|
||||
|
||||
|
||||
def kaiming_normal_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
|
||||
def _calculate_correct_fan(tensor, mode):
|
||||
mode = mode.lower()
|
||||
valid_modes = ["fan_in", "fan_out"]
|
||||
if mode not in valid_modes:
|
||||
raise ValueError(
|
||||
"Mode {} not supported, please use one of {}".format(mode, valid_modes)
|
||||
)
|
||||
|
||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
||||
return fan_in if mode == "fan_in" else fan_out
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
|
||||
linear_fns = [
|
||||
"linear",
|
||||
"conv1d",
|
||||
"conv2d",
|
||||
"conv3d",
|
||||
"conv_transpose1d",
|
||||
"conv_transpose2d",
|
||||
"conv_transpose3d",
|
||||
]
|
||||
if nonlinearity in linear_fns or nonlinearity == "sigmoid":
|
||||
return 1
|
||||
elif nonlinearity == "tanh":
|
||||
return 5.0 / 3
|
||||
elif nonlinearity == "relu":
|
||||
return math.sqrt(2.0)
|
||||
elif nonlinearity == "leaky_relu":
|
||||
if param is None:
|
||||
negative_slope = 0.01
|
||||
elif (
|
||||
not isinstance(param, bool)
|
||||
and isinstance(param, int)
|
||||
or isinstance(param, float)
|
||||
):
|
||||
negative_slope = param
|
||||
else:
|
||||
raise ValueError("negative_slope {} not a valid number".format(param))
|
||||
return math.sqrt(2.0 / (1 + negative_slope**2))
|
||||
else:
|
||||
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
||||
|
||||
fan = _calculate_correct_fan(tensor, mode)
|
||||
gain = calculate_gain(nonlinearity, a)
|
||||
std = gain / math.sqrt(fan)
|
||||
with paddle.no_grad():
|
||||
paddle.nn.initializer.Normal(0, std)(tensor)
|
||||
return tensor
|
||||
Reference in New Issue
Block a user