移动doc_dewarp

This commit is contained in:
2024-09-24 17:10:56 +08:00
parent 3438cf6e0e
commit 7647df7d74
21 changed files with 0 additions and 0 deletions

View File

@@ -0,0 +1,167 @@
input/
output/
runs/
*_dump/
*_log/
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/

View File

@@ -0,0 +1,34 @@
repos:
# Common hooks
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
hooks:
- id: check-added-large-files
- id: check-merge-conflict
- id: check-symlinks
- id: detect-private-key
- id: end-of-file-fixer
- id: trailing-whitespace
- repo: https://github.com/Lucas-C/pre-commit-hooks.git
rev: v1.5.1
hooks:
- id: remove-crlf
- id: remove-tabs
name: Tabs remover (Python)
files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
args: [--whitespaces-count, '4']
# For Python files
- repo: https://github.com/psf/black.git
rev: 23.3.0
hooks:
- id: black
files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
- repo: https://github.com/pycqa/isort
rev: 5.11.5
hooks:
- id: isort
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.0.272
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix, --no-cache]

View File

@@ -0,0 +1,398 @@
import copy
from typing import Optional
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from .extractor import BasicEncoder
from .position_encoding import build_position_encoding
from .weight_init import weight_init_
class attnLayer(nn.Layer):
    """Transformer layer with self-attention plus two cross-attention branches.

    Each branch attends over one entry of ``memory_list`` (always two sources
    in this model), followed by a position-wise feed-forward network.

    Args:
        d_model: Embedding dimension.
        nhead: Number of attention heads.
        dim_feedforward: Hidden size of the feed-forward network.
        dropout: Dropout probability used throughout the layer.
        activation: FFN activation name ("relu", "gelu" or "glu").
        normalize_before: If True use pre-norm ordering, otherwise post-norm.
    """

    def __init__(
        self,
        d_model,
        nhead=8,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiHeadAttention(d_model, nhead, dropout=dropout)
        # One cross-attention branch per memory source (2 in this model).
        self.multihead_attn_list = nn.LayerList(
            [
                copy.deepcopy(nn.MultiHeadAttention(d_model, nhead, dropout=dropout))
                for i in range(2)
            ]
        )
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2_list = nn.LayerList(
            [copy.deepcopy(nn.LayerNorm(d_model)) for i in range(2)]
        )
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p=dropout)
        self.dropout2_list = nn.LayerList(
            [copy.deepcopy(nn.Dropout(p=dropout)) for i in range(2)]
        )
        self.dropout3 = nn.Dropout(p=dropout)
        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[paddle.Tensor]):
        """Add a positional encoding to *tensor*; no-op when *pos* is None."""
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory_list,
        tgt_mask=None,
        memory_mask=None,
        pos=None,
        memory_pos=None,
    ):
        """Post-norm path: attention -> add -> norm at every sub-layer.

        Inputs are sequence-first ``(seq, batch, dim)``; they are transposed
        because nn.MultiHeadAttention expects batch-first tensors.
        """
        q = k = self.with_pos_embed(tgt, pos)
        tgt2 = self.self_attn(
            q.transpose((1, 0, 2)),
            k.transpose((1, 0, 2)),
            value=tgt.transpose((1, 0, 2)),
            attn_mask=tgt_mask,
        )
        tgt2 = tgt2.transpose((1, 0, 2))
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        # One cross-attention sub-layer per memory source.
        for memory, multihead_attn, norm2, dropout2, m_pos in zip(
            memory_list,
            self.multihead_attn_list,
            self.norm2_list,
            self.dropout2_list,
            memory_pos,
        ):
            tgt2 = multihead_attn(
                query=self.with_pos_embed(tgt, pos).transpose((1, 0, 2)),
                key=self.with_pos_embed(memory, m_pos).transpose((1, 0, 2)),
                value=memory.transpose((1, 0, 2)),
                attn_mask=memory_mask,
            ).transpose((1, 0, 2))
            tgt = tgt + dropout2(tgt2)
            tgt = norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(
        self,
        tgt,
        memory_list,
        tgt_mask=None,
        memory_mask=None,
        pos=None,
        memory_pos=None,
    ):
        """Pre-norm variant of :meth:`forward_post`.

        Fixed: the original referenced ``self.norm2``, ``self.multihead_attn``
        and ``self.dropout2`` — attributes that are never created (only the
        ``*_list`` versions exist in ``__init__``) — and accepted a single
        memory although ``forward()`` passes a list, so it raised
        AttributeError whenever ``normalize_before=True``. It now mirrors
        :meth:`forward_post` with pre-norm ordering and the same batch-first
        transposes for nn.MultiHeadAttention.
        """
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, pos)
        tgt2 = self.self_attn(
            q.transpose((1, 0, 2)),
            k.transpose((1, 0, 2)),
            value=tgt2.transpose((1, 0, 2)),
            attn_mask=tgt_mask,
        ).transpose((1, 0, 2))
        tgt = tgt + self.dropout1(tgt2)
        for memory, multihead_attn, norm2, dropout2, m_pos in zip(
            memory_list,
            self.multihead_attn_list,
            self.norm2_list,
            self.dropout2_list,
            memory_pos,
        ):
            tgt2 = norm2(tgt)
            tgt2 = multihead_attn(
                query=self.with_pos_embed(tgt2, pos).transpose((1, 0, 2)),
                key=self.with_pos_embed(memory, m_pos).transpose((1, 0, 2)),
                value=memory.transpose((1, 0, 2)),
                attn_mask=memory_mask,
            ).transpose((1, 0, 2))
            tgt = tgt + dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(
        self,
        tgt,
        memory_list,
        tgt_mask=None,
        memory_mask=None,
        pos=None,
        memory_pos=None,
    ):
        """Dispatch to the pre-norm or post-norm implementation."""
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory_list,
                tgt_mask,
                memory_mask,
                pos,
                memory_pos,
            )
        return self.forward_post(
            tgt,
            memory_list,
            tgt_mask,
            memory_mask,
            pos,
            memory_pos,
        )
def _get_clones(module, N):
return nn.LayerList([copy.deepcopy(module) for i in range(N)])
def _get_activation_fn(activation):
"""Return an activation function given a string"""
if activation == "relu":
return F.relu
if activation == "gelu":
return F.gelu
if activation == "glu":
return F.glu
raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
class TransDecoder(nn.Layer):
    """Stack of cross-attention layers decoding a query map against an image
    feature map, using sine positional encodings over the spatial grid."""

    def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
        super(TransDecoder, self).__init__()
        self.layers = _get_clones(attnLayer(hidden_dim), num_attn_layers)
        self.position_embedding = build_position_encoding(hidden_dim)

    def forward(self, image: paddle.Tensor, query_embed: paddle.Tensor):
        """Decode *query_embed* against *image* (NCHW); returns NCHW."""
        batch, channels, height, width = image.shape
        valid = paddle.ones([batch, height, width], dtype="bool")
        pos = self.position_embedding(valid)
        # Flatten spatial dims and make the sequence axis leading: (HW, N, C).
        seq_image = image.flatten(2).transpose(perm=[2, 0, 1])
        seq_pos = pos.flatten(2).transpose(perm=[2, 0, 1])
        for attn in self.layers:
            query_embed = attn(
                query_embed, [seq_image], pos=seq_pos, memory_pos=[seq_pos, seq_pos]
            )
        return query_embed.transpose(perm=[1, 2, 0]).reshape(
            [batch, channels, height, width]
        )
class TransEncoder(nn.Layer):
    """Stack of self-attention layers over a flattened image feature map."""

    def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
        super(TransEncoder, self).__init__()
        self.layers = _get_clones(attnLayer(hidden_dim), num_attn_layers)
        self.position_embedding = build_position_encoding(hidden_dim)

    def forward(self, image: paddle.Tensor):
        """Self-attend over *image* (NCHW); returns a same-shaped NCHW map."""
        batch, channels, height, width = image.shape
        pos = self.position_embedding(
            paddle.ones([batch, height, width], dtype="bool")
        )
        # (N, C, H, W) -> (HW, N, C) sequence layout for the attention layers.
        seq = image.flatten(2).transpose(perm=[2, 0, 1])
        seq_pos = pos.flatten(2).transpose(perm=[2, 0, 1])
        for attn in self.layers:
            seq = attn(seq, [seq], pos=seq_pos, memory_pos=[seq_pos, seq_pos])
        return seq.transpose(perm=[1, 2, 0]).reshape(
            [batch, channels, height, width]
        )
class FlowHead(nn.Layer):
def __init__(self, input_dim=128, hidden_dim=256):
super(FlowHead, self).__init__()
self.conv1 = nn.Conv2D(input_dim, hidden_dim, 3, padding=1)
self.conv2 = nn.Conv2D(hidden_dim, 2, 3, padding=1)
self.relu = nn.ReLU()
def forward(self, x):
return self.conv2(self.relu(self.conv1(x)))
class UpdateBlock(nn.Layer):
    """Predicts a flow update and a convex-upsampling mask from features."""

    def __init__(self, hidden_dim: int = 128):
        super(UpdateBlock, self).__init__()
        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
        # 64 * 9 = (8x8 upsampled positions) x (3x3 neighborhood weights).
        self.mask = nn.Sequential(
            nn.Conv2D(hidden_dim, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2D(256, 64 * 9, 1, padding=0),
        )

    def forward(self, image, coords):
        """Return ``(mask, coords + predicted flow delta)``."""
        scaled_mask = self.mask(image) * 0.25
        updated_coords = coords + self.flow_head(image)
        return scaled_mask, updated_coords
def coords_grid(batch, ht, wd):
coords = paddle.meshgrid(paddle.arange(end=ht), paddle.arange(end=wd))
coords = paddle.stack(coords[::-1], axis=0).astype(dtype="float32")
return coords[None].tile([batch, 1, 1, 1])
def upflow8(flow, mode="bilinear"):
new_size = 8 * flow.shape[2], 8 * flow.shape[3]
return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
class OverlapPatchEmbed(nn.Layer):
    """Image to Patch Embedding with overlapping patches.

    A strided convolution whose padding is half the kernel makes neighboring
    patches overlap; tokens are then LayerNorm-ed.
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        if not isinstance(img_size, tuple):
            img_size = (img_size, img_size)
        if not isinstance(patch_size, tuple):
            patch_size = (patch_size, patch_size)
        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2D(
            in_chans,
            embed_dim,
            patch_size,
            stride=stride,
            padding=(patch_size[0] // 2, patch_size[1] // 2),
        )
        self.norm = nn.LayerNorm(embed_dim)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialize linear / norm / conv sublayers."""
        if isinstance(m, nn.Linear):
            weight_init_(m, "trunc_normal_", std=0.02)
        elif isinstance(m, nn.LayerNorm):
            weight_init_(m, "Constant", value=1.0)
        elif isinstance(m, nn.Conv2D):
            weight_init_(
                m.weight, "kaiming_normal_", mode="fan_out", nonlinearity="relu"
            )

    def forward(self, x):
        """Embed an NCHW image; returns (tokens (N, H*W, C), H, W)."""
        x = self.proj(x)
        _, _, H, W = x.shape
        # (N, C, H*W) -> (N, H*W, C); after flatten(2) the tensor is 3-D.
        tokens = x.flatten(2).transpose(perm=[0, 2, 1])
        return self.norm(tokens), H, W
class GeoTr(nn.Layer):
    """DocTr++ geometric rectification network.

    Encodes CNN features with transformer encoders at three scales,
    decodes them coarse-to-fine with transformer decoders, then predicts
    a dense backward map through a flow head with convex upsampling.
    """

    def __init__(self):
        super(GeoTr, self).__init__()
        self.hidden_dim = hdim = 256
        self.fnet = BasicEncoder(output_dim=hdim, norm_fn="instance")
        # Stage sub-layers are registered by string name via __setattr__ and
        # retrieved with __getattr__ in forward().
        self.encoder_block = [("encoder_block" + str(i)) for i in range(3)]
        for i in self.encoder_block:
            self.__setattr__(i, TransEncoder(2, hidden_dim=hdim))
        self.down_layer = [("down_layer" + str(i)) for i in range(2)]
        for i in self.down_layer:
            self.__setattr__(i, nn.Conv2D(256, 256, 3, stride=2, padding=1))
        self.decoder_block = [("decoder_block" + str(i)) for i in range(3)]
        for i in self.decoder_block:
            self.__setattr__(i, TransDecoder(2, hidden_dim=hdim))
        self.up_layer = [("up_layer" + str(i)) for i in range(2)]
        for i in self.up_layer:
            self.__setattr__(
                i, nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            )
        # 81 learned query tokens seed the coarsest decoder stage.
        self.query_embed = nn.Embedding(81, self.hidden_dim)
        self.update_block = UpdateBlock(self.hidden_dim)

    def initialize_flow(self, img):
        """Return a full-resolution and two 1/8-resolution identity grids."""
        N, _, H, W = img.shape
        coodslar = coords_grid(N, H, W)
        coords0 = coords_grid(N, H // 8, W // 8)
        coords1 = coords_grid(N, H // 8, W // 8)
        return coodslar, coords0, coords1

    def upsample_flow(self, flow, mask):
        """Convex upsampling: combine each pixel's 3x3 neighborhood of the
        1/8-resolution flow with learned softmax weights to produce the
        full-resolution flow (N, 2, 8H, 8W)."""
        N, _, H, W = flow.shape
        mask = mask.reshape([N, 1, 9, 8, 8, H, W])
        # Softmax over the 9 neighbors makes the combination convex.
        mask = F.softmax(mask, axis=2)
        up_flow = F.unfold(8 * flow, [3, 3], paddings=1)
        up_flow = up_flow.reshape([N, 2, 9, 1, 1, H, W])
        up_flow = paddle.sum(mask * up_flow, axis=2)
        up_flow = up_flow.transpose(perm=[0, 1, 4, 2, 5, 3])
        return up_flow.reshape([N, 2, 8 * H, 8 * W])

    def forward(self, image):
        """Predict the backward map for *image* (NCHW); same spatial size."""
        fmap = self.fnet(image)
        fmap = F.relu(fmap)
        # Encoder pyramid: progressively downsampled feature maps.
        fmap1 = self.__getattr__(self.encoder_block[0])(fmap)
        fmap1d = self.__getattr__(self.down_layer[0])(fmap1)
        fmap2 = self.__getattr__(self.encoder_block[1])(fmap1d)
        fmap2d = self.__getattr__(self.down_layer[1])(fmap2)
        fmap3 = self.__getattr__(self.encoder_block[2])(fmap2d)
        # Coarse-to-fine decoding; each stage's (upsampled, re-flattened)
        # output becomes the next stage's query.
        query_embed0 = self.query_embed.weight.unsqueeze(1).tile([1, fmap3.shape[0], 1])
        fmap3d_ = self.__getattr__(self.decoder_block[0])(fmap3, query_embed0)
        fmap3du_ = (
            self.__getattr__(self.up_layer[0])(fmap3d_)
            .flatten(2)
            .transpose(perm=[2, 0, 1])
        )
        fmap2d_ = self.__getattr__(self.decoder_block[1])(fmap2, fmap3du_)
        fmap2du_ = (
            self.__getattr__(self.up_layer[1])(fmap2d_)
            .flatten(2)
            .transpose(perm=[2, 0, 1])
        )
        fmap_out = self.__getattr__(self.decoder_block[2])(fmap1, fmap2du_)
        # Predict the residual flow at 1/8 resolution, then convex-upsample
        # and add the identity grid to obtain the backward map.
        coodslar, coords0, coords1 = self.initialize_flow(image)
        coords1 = coords1.detach()
        mask, coords1 = self.update_block(fmap_out, coords1)
        flow_up = self.upsample_flow(coords1 - coords0, mask)
        bm_up = coodslar + flow_up
        return bm_up

View File

@@ -0,0 +1,73 @@
# DocTrPP: DocTr++ in PaddlePaddle
## Introduction
This is a PaddlePaddle implementation of DocTr++. The original paper is [DocTr++: Deep Unrestricted Document Image Rectification](https://arxiv.org/abs/2304.08796). The original code is [here](https://github.com/fh2019ustc/DocTr-Plus).
![demo](https://github.com/GreatV/DocTrPP/assets/17264618/4e491512-bfc4-4e69-a833-fd1c6e17158c)
## Requirements
You need to install the latest version of PaddlePaddle; installation instructions are available at this [link](https://www.paddlepaddle.org.cn/).
## Training
1. Data Preparation
To prepare datasets, refer to [doc3D](https://github.com/cvlab-stonybrook/doc3D-dataset).
2. Training
```shell
sh train.sh
```
or
```shell
export OPENCV_IO_ENABLE_OPENEXR=1
export CUDA_VISIBLE_DEVICES=0
python train.py --img-size 288 \
--name "DocTr++" \
--batch-size 12 \
--lr 2.5e-5 \
--exist-ok \
--use-vdl
```
3. Load Trained Model and Continue Training
```shell
export OPENCV_IO_ENABLE_OPENEXR=1
export CUDA_VISIBLE_DEVICES=0
python train.py --img-size 288 \
--name "DocTr++" \
--batch-size 12 \
--lr 2.5e-5 \
--resume "runs/train/DocTr++/weights/last.ckpt" \
--exist-ok \
--use-vdl
```
## Test and Inference
Test the dewarp result on a single image:
```shell
python predict.py -i "crop/12_2 copy.png" -m runs/train/DocTr++/weights/best.ckpt -o 12.2.png
```
![document image rectification](https://raw.githubusercontent.com/greatv/DocTrPP/main/doc/imgs/document_image_rectification.jpg)
## Export to onnx
```
pip install paddle2onnx
python export.py -m ./best.ckpt --format onnx
```
## Model Download
The trained model can be downloaded from [here](https://github.com/GreatV/DocTrPP/releases/download/v0.0.2/best.ckpt).

View File

@@ -0,0 +1,4 @@
from onnxruntime import InferenceSession

# Shared ONNX Runtime session for the DocTr++ dewarping model.
# Fixed: the original listed only CUDAExecutionProvider, so session creation
# failed (or silently mis-configured) on hosts without CUDA. ONNX Runtime
# recommends always appending CPUExecutionProvider as a fallback; provider
# options are matched positionally to the providers list.
DOC_TR = InferenceSession(
    "model/dewarp_model/doc_tr_pp.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
    provider_options=[{"device_id": 0}, {}],
)

View File

@@ -0,0 +1,133 @@
"""Visualize one random doc3d sample with all of its ground truths.

Loads the RGB image plus 3D world coordinates, backward map, UV map, depth
map, normal map, albedo and checkerboard reconstruction for a randomly
chosen file, unwarps the albedo with the backward map, and saves a 2x5
matplotlib figure to --output.
"""
import argparse
import os
import random
import cv2
import hdf5storage as h5
import matplotlib.pyplot as plt
import numpy as np
import paddle
import paddle.nn.functional as F

# NOTE(review): set after `import cv2`; confirm the target OpenCV build
# honors this at imread time rather than at import time.
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
parser = argparse.ArgumentParser()
parser.add_argument(
    "--data_root",
    nargs="?",
    type=str,
    default="~/datasets/doc3d/",
    help="Path to the downloaded dataset",
)
parser.add_argument(
    "--folder", nargs="?", type=int, default=1, help="Folder ID to read from"
)
parser.add_argument(
    "--output",
    nargs="?",
    type=str,
    default="output.png",
    help="Output filename for the image",
)
args = parser.parse_args()
root = os.path.expanduser(args.data_root)
folder = args.folder
# Pick a random PNG from the chosen folder.
dirname = os.path.join(root, "img", str(folder))
choices = [f for f in os.listdir(dirname) if "png" in f]
fname = random.choice(choices)
# Read Image
img_path = os.path.join(dirname, fname)
img = cv2.imread(img_path)
# Read 3D Coords
wc_path = os.path.join(root, "wc", str(folder), fname[:-3] + "exr")
wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
# scale wc
# value obtained from the entire dataset
# (per-axis extrema of the world coordinates across all of doc3d)
xmx, xmn, ymx, ymn, zmx, zmn = (
    1.2539363,
    -1.2442188,
    1.2396319,
    -1.2289206,
    0.6436657,
    -0.67492497,
)
wc[:, :, 0] = (wc[:, :, 0] - zmn) / (zmx - zmn)
wc[:, :, 1] = (wc[:, :, 1] - ymn) / (ymx - ymn)
wc[:, :, 2] = (wc[:, :, 2] - xmn) / (xmx - xmn)
# Read Backward Map
bm_path = os.path.join(root, "bm", str(folder), fname[:-3] + "mat")
bm = h5.loadmat(bm_path)["bm"]
# Read UV Map
uv_path = os.path.join(root, "uv", str(folder), fname[:-3] + "exr")
uv = cv2.imread(uv_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
# Read Depth Map
dmap_path = os.path.join(root, "dmap", str(folder), fname[:-3] + "exr")
dmap = cv2.imread(dmap_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)[:, :, 0]
# do some clipping and scaling to display it
dmap[dmap > 30.0] = 30
dmap = 1 - ((dmap - np.min(dmap)) / (np.max(dmap) - np.min(dmap)))
# Read Normal Map
norm_path = os.path.join(root, "norm", str(folder), fname[:-3] + "exr")
norm = cv2.imread(norm_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
# Read Albedo
alb_path = os.path.join(root, "alb", str(folder), fname[:-3] + "png")
alb = cv2.imread(alb_path)
# Read Checkerboard Image
# (recon files share the sample prefix plus a fixed "chess480001" suffix)
recon_path = os.path.join(root, "recon", str(folder), fname[:-8] + "chess480001.png")
recon = cv2.imread(recon_path)
# Display image and GTs
# use the backward mapping to dewarp the image
# scale bm to -1.0 to 1.0
bm_ = bm / np.array([448, 448])
bm_ = (bm_ - 0.5) * 2
bm_ = np.reshape(bm_, (1, 448, 448, 2))
bm_ = paddle.to_tensor(bm_, dtype="float32")
# Albedo to NCHW float in [0, 1] for grid_sample.
img_ = alb.transpose((2, 0, 1)).astype(np.float32) / 255.0
img_ = np.expand_dims(img_, 0)
img_ = paddle.to_tensor(img_, dtype="float32")
uw = F.grid_sample(img_, bm_)
uw = uw[0].numpy().transpose((1, 2, 0))
# 2x5 panel grid with all axis ticks removed.
f, axrr = plt.subplots(2, 5)
for ax in axrr:
    for a in ax:
        a.set_xticks([])
        a.set_yticks([])
axrr[0][0].imshow(img)
axrr[0][0].title.set_text("image")
axrr[0][1].imshow(wc)
axrr[0][1].title.set_text("3D coords")
axrr[0][2].imshow(bm[:, :, 0])
axrr[0][2].title.set_text("bm 0")
axrr[0][3].imshow(bm[:, :, 1])
axrr[0][3].title.set_text("bm 1")
# Some samples lack a UV map (imread returns None); show a black panel.
if uv is None:
    uv = np.zeros_like(img)
axrr[0][4].imshow(uv)
axrr[0][4].title.set_text("uv map")
axrr[1][0].imshow(dmap)
axrr[1][0].title.set_text("depth map")
axrr[1][1].imshow(norm)
axrr[1][1].title.set_text("normal map")
axrr[1][2].imshow(alb)
axrr[1][2].title.set_text("albedo")
axrr[1][3].imshow(recon)
axrr[1][3].title.set_text("checkerboard")
axrr[1][4].imshow(uw)
axrr[1][4].title.set_text("gt unwarped")
plt.tight_layout()
plt.savefig(args.output)

View File

@@ -0,0 +1,21 @@
import cv2
import numpy as np
import paddle
from . import DOC_TR
from .utils import to_tensor, to_image
def dewarp_image(image):
    """Rectify (dewarp) a document photo with the DocTr++ ONNX model.

    Args:
        image: HWC numpy image of arbitrary size.

    Returns:
        The dewarped image at the original resolution, converted back
        with ``to_image``.
    """
    # The ONNX model expects a 288x288 float32 CHW input.
    img = cv2.resize(image, (288, 288)).astype(np.float32)
    y = to_tensor(image)  # full-resolution tensor sampled below
    img = np.transpose(img, (2, 0, 1))
    # Backward map in pixel units of the 288x288 grid, NCHW.
    bm = DOC_TR.run(None, {"image": img[None,]})[0]
    bm = paddle.to_tensor(bm)
    # Upscale the backward map to the original image resolution.
    bm = paddle.nn.functional.interpolate(
        bm, y.shape[2:], mode="bilinear", align_corners=False
    )
    # Fixed: np.transpose on a paddle.Tensor converts it to a numpy array,
    # which F.grid_sample rejects as the grid argument; transpose natively.
    bm_nhwc = bm.transpose(perm=[0, 2, 3, 1])
    # Normalize pixel coordinates to [-1, 1] for grid_sample.
    out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
    return to_image(out)

View File

@@ -0,0 +1,161 @@
#!/usr/bin/env bash
# Download the doc3d dataset archives from the PaddleSeg mirror.
# The download order matches the original flat wget list, including its
# gaps: uv has no part 9 and alb only ships parts 1-8.

BASE="https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d"

# fetch PREFIX IDX... — download ${PREFIX}_${IDX}.zip for each index given.
fetch() {
    local prefix="$1"
    shift
    local idx
    for idx in "$@"; do
        wget "${BASE}/${prefix}_${idx}.zip"
    done
}

fetch img   $(seq 1 21)
fetch wc    $(seq 1 21)
fetch bm    $(seq 1 21)
fetch uv    $(seq 1 8) $(seq 10 21)   # uv_9.zip does not exist upstream
fetch alb   $(seq 1 8)                # albedo only has parts 1-8
fetch recon $(seq 1 21)
fetch norm  $(seq 1 21)
fetch dmap  $(seq 1 21)

Binary file not shown.

After

Width:  |  Height:  |  Size: 76 KiB

View File

@@ -0,0 +1,129 @@
import collections
import os
import random
import albumentations as A
import cv2
import hdf5storage as h5
import numpy as np
import paddle
from paddle import io
# Set random seed
random.seed(12345678)
class Doc3dDataset(io.Dataset):
    """doc3d document-dewarping dataset.

    Each sample pairs an RGB document photo with its backward map (bm): the
    coordinate grid that maps the distorted image back to a flat page.  The
    world-coordinate map (wc) is only used to crop tightly around the page.

    Expected layout under ``root``:
        train.txt / val.txt : one ``subdir/stem`` id per line
        img/<id>.png        : rendered document image
        wc/<id>.exr         : 3-channel float world coordinates
        bm/<id>.mat         : backward map stored under key "bm"
    """

    def __init__(self, root, split="train", is_augment=False, image_size=512):
        # root: dataset directory; split: "train" or "val";
        # is_augment: apply color jitter; image_size: int or (w, h) output size.
        self.root = os.path.expanduser(root)
        self.split = split
        self.is_augment = is_augment
        self.files = collections.defaultdict(list)
        self.image_size = (
            image_size if isinstance(image_size, tuple) else (image_size, image_size)
        )
        # Augmentation
        self.augmentation = A.Compose(
            [
                A.ColorJitter(),
            ]
        )
        # NOTE(review): the loop variable shadows the ``split`` argument;
        # harmless here because self.split was captured above.
        for split in ["train", "val"]:
            path = os.path.join(self.root, split + ".txt")
            file_list = []
            with open(path, "r") as file:
                file_list = [file_id.rstrip() for file_id in file.readlines()]
            self.files[split] = file_list

    def __len__(self):
        # Number of samples in the active split.
        return len(self.files[self.split])

    def __getitem__(self, index):
        """Load sample ``index`` and return the (image, bm) tensor pair."""
        image_name = self.files[self.split][index]
        # Read image
        image_path = os.path.join(self.root, "img", image_name + ".png")
        image = cv2.imread(image_path)
        # Read 3D Coordinates
        wc_path = os.path.join(self.root, "wc", image_name + ".exr")
        wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
        # Read backward map
        bm_path = os.path.join(self.root, "bm", image_name + ".mat")
        bm = h5.loadmat(bm_path)["bm"]
        image, bm = self.transform(wc, bm, image)
        return image, bm

    def tight_crop(self, wc: np.ndarray):
        """Crop ``wc`` to the page region with random margin jitter.

        Returns the cropped wc plus (top, bottom, left, right) margins that
        describe the same crop in original-image coordinates.
        """
        # Pixels where all three world coordinates are zero are background.
        mask = ((wc[:, :, 0] != 0) & (wc[:, :, 1] != 0) & (wc[:, :, 2] != 0)).astype(
            np.uint8
        )
        mask_size = mask.shape
        [y, x] = mask.nonzero()
        min_x = min(x)
        max_x = max(x)
        min_y = min(y)
        max_y = max(y)
        wc = wc[min_y : max_y + 1, min_x : max_x + 1, :]
        s = 10
        # Pad by s on each side, then randomly shave 0..2s back off per side.
        wc = np.pad(wc, ((s, s), (s, s), (0, 0)), "constant")
        cx1 = random.randint(0, 2 * s)
        cx2 = random.randint(0, 2 * s) + 1
        cy1 = random.randint(0, 2 * s)
        cy2 = random.randint(0, 2 * s) + 1
        wc = wc[cy1:-cy2, cx1:-cx2, :]
        top: int = min_y - s + cy1
        bottom: int = mask_size[0] - max_y - s + cy2
        left: int = min_x - s + cx1
        right: int = mask_size[1] - max_x - s + cx2
        # Clamp: bottom/right stay >= 1 because they are used below as
        # negative slice bounds (img[top:-bottom, left:-right]).
        top = np.clip(top, 0, mask_size[0])
        bottom = np.clip(bottom, 1, mask_size[0] - 1)
        left = np.clip(left, 0, mask_size[1])
        right = np.clip(right, 1, mask_size[1] - 1)
        return wc, top, bottom, left, right

    def transform(self, wc, bm, img):
        """Crop/resize ``img`` and rescale the backward map to match.

        Returns (img, bm): img as float32 CHW in [0, 1]; bm as HW2 in output
        pixel units.
        """
        wc, top, bottom, left, right = self.tight_crop(wc)
        img = img[top:-bottom, left:-right, :]
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        if self.is_augment:
            img = self.augmentation(image=img)["image"]
        # resize image
        img = cv2.resize(img, self.image_size)
        img = img.astype(np.float32) / 255.0
        img = img.transpose(2, 0, 1)
        # resize bm
        bm = bm.astype(np.float32)
        # Shift map into the cropped frame, then normalize by crop size.
        bm[:, :, 1] = bm[:, :, 1] - top
        bm[:, :, 0] = bm[:, :, 0] - left
        # 448 is presumably the native doc3d render resolution — TODO confirm.
        bm = bm / np.array([448.0 - left - right, 448.0 - top - bottom])
        bm0 = cv2.resize(bm[:, :, 0], (self.image_size[0], self.image_size[1]))
        bm1 = cv2.resize(bm[:, :, 1], (self.image_size[0], self.image_size[1]))
        # Back to pixel units of the resized output.
        bm0 = bm0 * self.image_size[0]
        bm1 = bm1 * self.image_size[1]
        bm = np.stack([bm0, bm1], axis=-1)
        img = paddle.to_tensor(img).astype(dtype="float32")
        bm = paddle.to_tensor(bm).astype(dtype="float32")
        return img, bm

View File

@@ -0,0 +1,66 @@
import argparse
import os
import paddle
from GeoTr import GeoTr
def export(args):
    """Export a trained GeoTr checkpoint to a static graph and optionally ONNX.

    Args:
        args: parsed CLI namespace carrying ``model`` (checkpoint path),
            ``imgsz`` (square input size) and ``format`` ("static" or "onnx").

    Side effects:
        Writes ``model.pdmodel``/``model.pdiparams`` — and ``model.onnx`` when
        requested — into the checkpoint's directory.
    """
    model_path = args.model
    imgsz = args.imgsz
    export_format = args.format  # renamed: ``format`` shadowed the builtin
    model = GeoTr()
    checkpoint = paddle.load(model_path)
    model.set_state_dict(checkpoint["model"])
    model.eval()
    dirname = os.path.dirname(model_path)
    if export_format in ("static", "onnx"):
        # Trace the dygraph model into a static program with a fixed
        # 1x3xHxW float32 input.
        model = paddle.jit.to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(shape=[1, 3, imgsz, imgsz], dtype="float32")
            ],
            full_graph=True,
        )
        paddle.jit.save(model, os.path.join(dirname, "model"))
    if export_format == "onnx":
        # Convert the saved static graph with the paddle2onnx CLI.
        onnx_path = os.path.join(dirname, "model.onnx")
        os.system(
            f"paddle2onnx --model_dir {dirname}"
            " --model_filename model.pdmodel"
            " --params_filename model.pdiparams"
            f" --save_file {onnx_path}"
            " --opset_version 11"
        )
if __name__ == "__main__":
    # CLI: python export.py -m path/to/ckpt [--imgsz 288] [--format static|onnx]
    parser = argparse.ArgumentParser(description="export model")
    parser.add_argument(
        "--model",
        "-m",
        nargs="?",
        type=str,
        default="",
        help="The path of model",
    )
    parser.add_argument(
        "--imgsz", type=int, default=288, help="The size of input image"
    )
    parser.add_argument(
        "--format",
        type=str,
        default="static",
        help="The format of exported model, which can be static or onnx",
    )
    args = parser.parse_args()
    export(args)

View File

@@ -0,0 +1,110 @@
import paddle.nn as nn
from .weight_init import weight_init_
class ResidualBlock(nn.Layer):
    """Two-conv residual block with a configurable normalization layer.

    Args:
        in_planes: number of input channels.
        planes: number of output channels.
        norm_fn: one of "group", "batch", "instance" or "none".
        stride: stride of the first conv; when != 1 a strided 1x1 downsample
            branch is added so the skip connection matches spatially.

    Raises:
        ValueError: for an unrecognized ``norm_fn`` (previously this left
            ``norm1``/``norm3`` undefined and failed later with a cryptic
            AttributeError).
    """

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2D(in_planes, planes, 3, padding=1, stride=stride)
        self.conv2 = nn.Conv2D(planes, planes, 3, padding=1)
        self.relu = nn.ReLU()
        if norm_fn == "group":
            num_groups = planes // 8
            self.norm1 = nn.GroupNorm(num_groups, planes)
            self.norm2 = nn.GroupNorm(num_groups, planes)
            if stride != 1:
                self.norm3 = nn.GroupNorm(num_groups, planes)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2D(planes)
            self.norm2 = nn.BatchNorm2D(planes)
            if stride != 1:
                self.norm3 = nn.BatchNorm2D(planes)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2D(planes)
            self.norm2 = nn.InstanceNorm2D(planes)
            if stride != 1:
                self.norm3 = nn.InstanceNorm2D(planes)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if stride != 1:
                self.norm3 = nn.Sequential()
        else:
            # Fail fast instead of surfacing as an AttributeError later.
            raise ValueError(f"Unsupported norm_fn: {norm_fn!r}")
        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2D(in_planes, planes, 1, stride=stride), self.norm3
            )

    def forward(self, x):
        """Return relu(x + F(x)), downsampling the skip path when strided."""
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))
        if self.downsample is not None:
            x = self.downsample(x)
        return self.relu(x + y)
class BasicEncoder(nn.Layer):
    """Basic Encoder with custom normalization.

    A strided 7x7 stem followed by three residual stages
    (64 -> 128 -> 192 channels, two blocks each) and a 1x1 projection to
    ``output_dim`` channels.

    Raises:
        ValueError: for an unrecognized ``norm_fn`` (previously this left
            ``self.norm1`` undefined and failed later in ``forward``).
    """

    def __init__(self, output_dim=128, norm_fn="batch"):
        super(BasicEncoder, self).__init__()
        self.norm_fn = norm_fn
        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(8, 64)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2D(64)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2D(64)
        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()
        else:
            raise ValueError(f"Unsupported norm_fn: {norm_fn!r}")
        self.conv1 = nn.Conv2D(3, 64, 7, stride=2, padding=3)
        self.relu1 = nn.ReLU()
        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(128, stride=2)
        self.layer3 = self._make_layer(192, stride=2)
        self.conv2 = nn.Conv2D(192, output_dim, 1)
        # Kaiming-init convs; unit-init normalization affine parameters.
        # NOTE(review): weight_init_ receives m.weight (a parameter) in the
        # first branch but the layer itself in the second — confirm
        # weight_init_ supports both call styles.
        for m in self.sublayers():
            if isinstance(m, nn.Conv2D):
                weight_init_(
                    m.weight, "kaiming_normal_", mode="fan_out", nonlinearity="relu"
                )
            elif isinstance(m, (nn.BatchNorm2D, nn.InstanceNorm2D, nn.GroupNorm)):
                weight_init_(m, "Constant", value=1, bias_value=0.0)

    def _make_layer(self, dim, stride=1):
        """Stack two residual blocks; only the first may be strided."""
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = (layer1, layer2)
        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        """Encode an NCHW image batch into an ``output_dim``-channel feature map."""
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.conv2(x)
        return x

View File

@@ -0,0 +1,48 @@
import copy
import os
import matplotlib.pyplot as plt
import paddle.optimizer as optim
from GeoTr import GeoTr
def plot_lr_scheduler(optimizer, scheduler, epochs=65, save_dir="", steps_per_epoch=30):
    """
    Plot the learning rate scheduler.

    Steps a copied optimizer/scheduler pair through ``epochs`` epochs of
    ``steps_per_epoch`` iterations each and saves the sampled LR curve to
    ``save_dir/lr_scheduler.png``.

    Args:
        optimizer: paddle optimizer driving the scheduler.
        scheduler: paddle LR scheduler to sample.
        epochs: number of epochs to simulate.
        save_dir: directory the figure is written to.
        steps_per_epoch: scheduler steps per epoch (previously hard-coded 30).
    """
    # NOTE(review): copy.copy is shallow, so stepping may still mutate shared
    # internal state of the originals — pass throwaway objects if that matters.
    optimizer = copy.copy(optimizer)
    scheduler = copy.copy(scheduler)
    lr = []
    for _ in range(epochs):
        for _ in range(steps_per_epoch):
            lr.append(scheduler.get_lr())
            optimizer.step()
            scheduler.step()
    # x axis in (fractional) epochs.
    epoch = [float(i) / steps_per_epoch for i in range(len(lr))]
    plt.figure()
    plt.plot(epoch, lr, ".-", label="Learning Rate")
    plt.xlabel("epoch")
    plt.ylabel("Learning Rate")
    plt.title("Learning Rate Scheduler")
    plt.savefig(os.path.join(save_dir, "lr_scheduler.png"), dpi=300)
    plt.close()
if __name__ == "__main__":
    # Build the same OneCycle schedule used for training and visualize it.
    model = GeoTr()
    scheduler = optim.lr.OneCycleLR(
        max_learning_rate=1e-4,
        total_steps=1950,
        phase_pct=0.1,
        end_learning_rate=1e-4 / 2.5e5,
    )
    optimizer = optim.AdamW(learning_rate=scheduler, parameters=model.parameters())
    plot_lr_scheduler(
        optimizer=optimizer, scheduler=scheduler, epochs=65, save_dir="./"
    )

View File

@@ -0,0 +1,124 @@
import math
from typing import Optional
import paddle
import paddle.nn as nn
import paddle.nn.initializer as init
class NestedTensor(object):
    """Bundle a batch of tensors with an optional padding mask."""

    def __init__(self, tensors, mask: Optional[paddle.Tensor]):
        # tensors: batched data; mask: padding mask aligned with it (or None).
        self.tensors = tensors
        self.mask = mask

    def decompose(self):
        """Return the wrapped (tensors, mask) pair."""
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)
class PositionEmbeddingSine(nn.Layer):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        super().__init__()
        self.num_pos_feats = num_pos_feats  # channels per axis (output has 2x)
        self.temperature = temperature  # base of the frequency ladder
        self.normalize = normalize  # scale positions into [0, scale]
        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")
        if scale is None:
            scale = 2 * math.pi
        self.scale = scale

    def forward(self, mask):
        """Build an (N, 2*num_pos_feats, H, W) sinusoidal embedding from ``mask``.

        ``mask`` is cumsum-ed along H and W to derive per-pixel positions;
        # NOTE(review): assumes mask is (N, H, W) with 1 at valid pixels —
        # confirm against the caller.
        """
        assert mask is not None
        y_embed = mask.cumsum(axis=1, dtype="float32")  # row position per pixel
        x_embed = mask.cumsum(axis=2, dtype="float32")  # column position per pixel
        if self.normalize:
            eps = 1e-06
            # Normalize by the last (largest) cumsum value, then scale.
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
        dim_t = paddle.arange(end=self.num_pos_feats, dtype="float32")
        # Geometric frequency ladder: temperature^(2*floor(i/2)/num_pos_feats).
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t
        # Interleave sin on even channels and cos on odd channels.
        pos_x = paddle.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), axis=4
        ).flatten(3)
        pos_y = paddle.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), axis=4
        ).flatten(3)
        # Concatenate y/x channels, then move channels to NCHW.
        pos = paddle.concat((pos_y, pos_x), axis=3).transpose(perm=[0, 3, 1, 2])
        return pos
class PositionEmbeddingLearned(nn.Layer):
    """
    Absolute pos embedding, learned.
    """

    def __init__(self, num_pos_feats=256):
        super().__init__()
        # One learned table per axis; 50 slots bound the spatial extent.
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-draw both embedding tables from a uniform distribution."""
        uniform_ = init.Uniform()
        uniform_(self.row_embed.weight)
        uniform_(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        """Return an (N, 2*num_pos_feats, H, W) learned position map."""
        x = tensor_list.tensors
        h, w = x.shape[-2:]
        cols = self.col_embed(paddle.arange(end=w))  # (w, d)
        rows = self.row_embed(paddle.arange(end=h))  # (h, d)
        # Broadcast each axis embedding across the other axis, then join.
        grid = paddle.concat(
            [
                cols.unsqueeze(0).tile([h, 1, 1]),
                rows.unsqueeze(1).tile([1, w, 1]),
            ],
            axis=-1,
        )
        pos = grid.transpose([2, 0, 1]).unsqueeze(0)
        return pos.tile([x.shape[0], 1, 1, 1])
def build_position_encoding(hidden_dim=512, position_embedding="sine"):
    """Factory for position encodings: "sine"/"v2" or "learned"/"v3"."""
    N_steps = hidden_dim // 2  # half the hidden size per spatial axis
    if position_embedding in ("v2", "sine"):
        return PositionEmbeddingSine(N_steps, normalize=True)
    if position_embedding in ("v3", "learned"):
        return PositionEmbeddingLearned(N_steps)
    raise ValueError(f"not supported {position_embedding}")

View File

@@ -0,0 +1,69 @@
import argparse
import cv2
import paddle
from GeoTr import GeoTr
from utils import to_image, to_tensor
def run(args):
    """Dewarp a single document image with a trained GeoTr checkpoint.

    Loads ``args.model``, predicts a backward map from a 288x288 copy of
    ``args.image``, upsamples the map to the original resolution and
    grid-samples the original image, writing the result to ``args.output``.
    """
    image_path = args.image
    model_path = args.model
    output_path = args.output
    checkpoint = paddle.load(model_path)
    state_dict = checkpoint["model"]
    model = GeoTr()
    model.set_state_dict(state_dict)
    model.eval()
    img_org = cv2.imread(image_path)
    # The network consumes 288x288 inputs (same value used in the
    # normalization below).
    img = cv2.resize(img_org, (288, 288))
    x = to_tensor(img)
    y = to_tensor(img_org)
    # Inference only — disable autograd bookkeeping (was missing before).
    with paddle.no_grad():
        bm = model(x)
        bm = paddle.nn.functional.interpolate(
            bm, y.shape[2:], mode="bilinear", align_corners=False
        )
        bm_nhwc = bm.transpose([0, 2, 3, 1])
        # grid_sample expects coordinates in [-1, 1]; bm is in 288-grid units.
        out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
    out_image = to_image(out)
    cv2.imwrite(output_path, out_image)
if __name__ == "__main__":
    # CLI: dewarp one image with a trained checkpoint and save the result.
    parser = argparse.ArgumentParser(description="predict")
    parser.add_argument(
        "--image",
        "-i",
        nargs="?",
        type=str,
        default="",
        help="The path of image",
    )
    parser.add_argument(
        "--model",
        "-m",
        nargs="?",
        type=str,
        default="",
        help="The path of model",
    )
    parser.add_argument(
        "--output",
        "-o",
        nargs="?",
        type=str,
        default="",
        help="The path of output",
    )
    args = parser.parse_args()
    print(args)
    run(args)

View File

@@ -0,0 +1,7 @@
hdf5storage
loguru
numpy
scipy
opencv-python
matplotlib
albumentations

View File

@@ -0,0 +1,57 @@
import argparse
import glob
import os
import random
from pathlib import Path
from loguru import logger
random.seed(1234567)  # fixed seed so the train/val split is reproducible
def _write_split(path, items):
    """Write one ``subdir/stem`` label per line for every image path in *items*."""
    with open(path, "w") as file:
        for item in items:
            parts = Path(item).parts
            rel = os.path.join(parts[-2], parts[-1])
            file.write("%s\n" % rel.split(".png")[0])


def run(args):
    """Split the doc3d image list into train/val label files.

    Globs ``<data_root>/img/*/*.png``, shuffles deterministically (module
    seed), and writes ``train.txt``/``val.txt`` with ``args.train_ratio`` of
    the images assigned to train.
    """
    data_root = os.path.expanduser(args.data_root)
    ratio = args.train_ratio
    data_path = os.path.join(data_root, "img", "*", "*.png")
    img_list = glob.glob(data_path, recursive=True)
    # Sort before shuffling so the split is reproducible across filesystems.
    # Bug fix: ``sorted(img_list)`` discarded its result (a no-op).
    img_list.sort()
    random.shuffle(img_list)
    train_size = int(len(img_list) * ratio)
    train_text_path = os.path.join(data_root, "train.txt")
    _write_split(train_text_path, img_list[:train_size])
    val_text_path = os.path.join(data_root, "val.txt")
    _write_split(val_text_path, img_list[train_size:])
    logger.info(f"TRAIN LABEL: {train_text_path}")
    logger.info(f"VAL LABEL: {val_text_path}")
if __name__ == "__main__":
    # CLI: generate train.txt / val.txt under --data_root.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_root",
        type=str,
        default="~/datasets/doc3d",
        help="Data path to load data",
    )
    parser.add_argument(
        "--train_ratio", type=float, default=0.8, help="Ratio of training data"
    )
    args = parser.parse_args()
    logger.info(args)
    run(args)

View File

@@ -0,0 +1,381 @@
import argparse
import inspect
import os
import random
from pathlib import Path
from typing import Optional
import numpy as np
import paddle
import paddle.nn as nn
import paddle.optimizer as optim
from loguru import logger
from paddle.io import DataLoader
from paddle.nn import functional as F
from paddle_msssim import ms_ssim, ssim
from doc3d_dataset import Doc3dDataset
from GeoTr import GeoTr
from utils import to_image
FILE = Path(__file__).resolve()  # absolute path of this script
ROOT = FILE.parents[0]  # directory containing this script
RANK = int(os.getenv("RANK", -1))  # distributed rank; -1 = not launched distributed
def init_seeds(seed=0, deterministic=False):
    """Seed the python, numpy and paddle RNGs for reproducible runs.

    When ``deterministic`` is set, also pin the cuDNN/cuBLAS kernel choices
    and the hash seed so repeated runs produce identical results.
    """
    for seed_fn in (random.seed, np.random.seed, paddle.seed):
        seed_fn(seed)
    if deterministic:
        os.environ["FLAGS_cudnn_deterministic"] = "1"
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        os.environ["PYTHONHASHSEED"] = str(seed)
def colorstr(*inputs):
    """Wrap a string in ANSI escape codes.

    Usage: ``colorstr('blue', 'hello world')`` — leading arguments name the
    styles, the last argument is the text.  With a single argument the text
    is rendered bold blue.  See https://en.wikipedia.org/wiki/ANSI_escape_code.
    """
    colors = {
        "black": "\033[30m",  # basic colors
        "red": "\033[31m",
        "green": "\033[32m",
        "yellow": "\033[33m",
        "blue": "\033[34m",
        "magenta": "\033[35m",
        "cyan": "\033[36m",
        "white": "\033[37m",
        "bright_black": "\033[90m",  # bright colors
        "bright_red": "\033[91m",
        "bright_green": "\033[92m",
        "bright_yellow": "\033[93m",
        "bright_blue": "\033[94m",
        "bright_magenta": "\033[95m",
        "bright_cyan": "\033[96m",
        "bright_white": "\033[97m",
        "end": "\033[0m",  # misc
        "bold": "\033[1m",
        "underline": "\033[4m",
    }
    # Single argument -> default bold-blue styling.
    *styles, string = inputs if len(inputs) > 1 else ("blue", "bold", inputs[0])
    prefix = "".join(colors[s] for s in styles)
    return f"{prefix}{string}{colors['end']}"
def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
    # Print function arguments (optional args dict)
    # Inspects the caller's frame, so with ``args=None`` it can log the
    # calling function's own arguments automatically.
    x = inspect.currentframe().f_back  # previous frame
    file, _, func, _, _ = inspect.getframeinfo(x)
    if args is None:  # get args automatically
        args, _, _, frm = inspect.getargvalues(x)
        args = {k: v for k, v in frm.items() if k in args}
    try:
        # Prefer a path relative to ROOT; fall back to the bare stem.
        file = Path(file).resolve().relative_to(ROOT).with_suffix("")
    except ValueError:
        file = Path(file).stem
    s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
    logger.info(colorstr(s) + ", ".join(f"{k}={v}" for k, v in args.items()))
def increment_path(path, exist_ok=False, sep="", mkdir=False):
    """Return *path*, or the next free variant when it already exists.

    e.g. runs/exp -> runs/exp{sep}2, runs/exp{sep}3, ... For files the
    numeric suffix is inserted before the extension.  With ``exist_ok`` the
    path is returned unchanged; with ``mkdir`` the directory is created.
    """
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        # Split off the extension for files so "a.txt" becomes "a2.txt".
        stem, suffix = (
            (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
        )
        n = 2
        while n < 9998 and os.path.exists(f"{stem}{sep}{n}{suffix}"):
            n += 1
        path = Path(f"{stem}{sep}{n}{suffix}")
    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory
    return path
def train(args):
    """Train GeoTr on doc3d, validating each epoch and checkpointing.

    ``args`` must provide: save_dir, use_vdl, seed, data_root, img_size,
    batch_size, workers, lr, epochs and resume.  Writes ``weights/last.ckpt``
    every epoch and ``weights/best.ckpt`` whenever validation SSIM improves.
    """
    save_dir = Path(args.save_dir)
    use_vdl = args.use_vdl
    if use_vdl:
        from visualdl import LogWriter

        log_dir = save_dir / "vdl"
        vdl_writer = LogWriter(str(log_dir))
    # Directories — create the weights dir itself (previously only its
    # parent was created).
    weights_dir = save_dir / "weights"
    weights_dir.mkdir(parents=True, exist_ok=True)
    last = weights_dir / "last.ckpt"
    best = weights_dir / "best.ckpt"
    # Config
    init_seeds(args.seed)
    # Train loader
    train_dataset = Doc3dDataset(
        args.data_root,
        split="train",
        is_augment=True,
        image_size=args.img_size,
    )
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
    )
    # Validation loader
    val_dataset = Doc3dDataset(
        args.data_root,
        split="val",
        is_augment=False,
        image_size=args.img_size,
    )
    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, num_workers=args.workers
    )
    # Model
    model = GeoTr()
    if use_vdl:
        vdl_writer.add_graph(
            model,
            input_spec=[
                paddle.static.InputSpec([1, 3, args.img_size, args.img_size], "float32")
            ],
        )
    # Data Parallel Mode
    if RANK == -1 and paddle.device.cuda.device_count() > 1:
        model = paddle.DataParallel(model)
    # One-cycle LR schedule over the whole run
    scheduler = optim.lr.OneCycleLR(
        max_learning_rate=args.lr,
        total_steps=args.epochs * len(train_loader),
        phase_pct=0.1,
        end_learning_rate=args.lr / 2.5e5,
    )
    # Optimizer
    optimizer = optim.AdamW(
        learning_rate=scheduler,
        parameters=model.parameters(),
    )
    # loss functions: L1 is optimized, MSE is only monitored
    l1_loss_fn = nn.L1Loss()
    mse_loss_fn = nn.MSELoss()
    # Resume
    best_fitness, start_epoch = 0.0, 0
    if args.resume:
        ckpt = paddle.load(args.resume)
        model.set_state_dict(ckpt["model"])
        optimizer.set_state_dict(ckpt["optimizer"])
        scheduler.set_state_dict(ckpt["scheduler"])
        best_fitness = ckpt["best_fitness"]
        start_epoch = ckpt["epoch"] + 1
    # Train
    for epoch in range(start_epoch, args.epochs):
        model.train()
        for i, (img, target) in enumerate(train_loader):
            img = paddle.to_tensor(img)  # NCHW
            target = paddle.to_tensor(target)  # NHWC
            pred = model(img)  # NCHW
            pred_nhwc = pred.transpose([0, 2, 3, 1])
            loss = l1_loss_fn(pred_nhwc, target)
            mse_loss = mse_loss_fn(pred_nhwc, target)
            if use_vdl:
                vdl_writer.add_scalar(
                    "Train/L1 Loss", float(loss), epoch * len(train_loader) + i
                )
                vdl_writer.add_scalar(
                    "Train/MSE Loss", float(mse_loss), epoch * len(train_loader) + i
                )
                vdl_writer.add_scalar(
                    "Train/Learning Rate",
                    float(scheduler.get_lr()),
                    epoch * len(train_loader) + i,
                )
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.clear_grad()
            if i % 10 == 0:
                logger.info(
                    f"[TRAIN MODE] Epoch: {epoch}, Iter: {i}, L1 Loss: {float(loss)}, "
                    f"MSE Loss: {float(mse_loss)}, LR: {float(scheduler.get_lr())}"
                )
        # Validation
        model.eval()
        with paddle.no_grad():
            avg_ssim = paddle.zeros([])
            avg_ms_ssim = paddle.zeros([])
            avg_l1_loss = paddle.zeros([])
            avg_mse_loss = paddle.zeros([])
            for i, (img, target) in enumerate(val_loader):
                img = paddle.to_tensor(img)
                target = paddle.to_tensor(target)
                pred = model(img)
                pred_nhwc = pred.transpose([0, 2, 3, 1])
                # Warp the input with the predicted / ground-truth maps
                # (grid_sample expects coordinates in [-1, 1]).
                out = F.grid_sample(img, (pred_nhwc / args.img_size - 0.5) * 2)
                out_gt = F.grid_sample(img, (target / args.img_size - 0.5) * 2)
                # calculate ssim
                ssim_val = ssim(out, out_gt, data_range=1.0)
                ms_ssim_val = ms_ssim(out, out_gt, data_range=1.0)
                loss = l1_loss_fn(pred_nhwc, target)
                mse_loss = mse_loss_fn(pred_nhwc, target)
                # Accumulate epoch sums; averaged after the loop.
                avg_ssim += ssim_val
                avg_ms_ssim += ms_ssim_val
                avg_l1_loss += loss
                avg_mse_loss += mse_loss
                if i % 10 == 0:
                    logger.info(
                        f"[VAL MODE] Epoch: {epoch}, VAL Iter: {i}, "
                        f"L1 Loss: {float(loss)} MSE Loss: {float(mse_loss)}, "
                        f"MS-SSIM: {float(ms_ssim_val)}, SSIM: {float(ssim_val)}"
                    )
                if use_vdl and i == 0:
                    # Log the first three predicted/target pairs of the batch
                    # (requires batch_size >= 3, as before).
                    for k in range(3):
                        vdl_writer.add_image(
                            f"Val/Predicted Image No.{k}", to_image(out[k]), epoch
                        )
                        vdl_writer.add_image(
                            f"Val/Target Image No.{k}", to_image(out_gt[k]), epoch
                        )
        avg_ssim /= len(val_loader)
        avg_ms_ssim /= len(val_loader)
        avg_l1_loss /= len(val_loader)
        avg_mse_loss /= len(val_loader)
        if use_vdl:
            # Bug fix: log the epoch averages; previously the last batch's
            # values were logged and the computed averages went unused.
            vdl_writer.add_scalar("Val/L1 Loss", float(avg_l1_loss), epoch)
            vdl_writer.add_scalar("Val/MSE Loss", float(avg_mse_loss), epoch)
            vdl_writer.add_scalar("Val/SSIM", float(avg_ssim), epoch)
            vdl_writer.add_scalar("Val/MS-SSIM", float(avg_ms_ssim), epoch)
        # Save — refresh best_fitness before serializing so best.ckpt does
        # not carry the stale previous value.
        improved = bool(best_fitness < avg_ssim)
        if improved:
            best_fitness = avg_ssim
        ckpt = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "best_fitness": best_fitness,
            "epoch": epoch,
        }
        paddle.save(ckpt, str(last))
        if improved:
            paddle.save(ckpt, str(best))
    if use_vdl:
        vdl_writer.close()
def main(args):
    """Log the CLI arguments, pick a unique save directory, then train."""
    print_args(vars(args))
    run_dir = Path(args.project) / args.name
    args.save_dir = str(increment_path(run_dir, exist_ok=args.exist_ok))
    train(args)
if __name__ == "__main__":
    # CLI entry point: hyperparameters and run bookkeeping for training.
    parser = argparse.ArgumentParser(description="Hyperparams")
    parser.add_argument(
        "--data-root",
        nargs="?",
        type=str,
        default="~/datasets/doc3d",
        help="The root path of the dataset",
    )
    parser.add_argument(
        "--img-size",
        nargs="?",
        type=int,
        default=288,
        help="The size of the input image",
    )
    parser.add_argument(
        "--epochs",
        nargs="?",
        type=int,
        default=65,
        help="The number of training epochs",
    )
    parser.add_argument(
        "--batch-size", nargs="?", type=int, default=12, help="Batch Size"
    )
    parser.add_argument(
        "--lr", nargs="?", type=float, default=1e-04, help="Learning Rate"
    )
    parser.add_argument(
        "--resume",
        nargs="?",
        type=str,
        default=None,
        help="Path to previous saved model to restart from",
    )
    parser.add_argument("--workers", type=int, default=8, help="max dataloader workers")
    parser.add_argument(
        "--project", default=ROOT / "runs/train", help="save to project/name"
    )
    parser.add_argument("--name", default="exp", help="save to project/name")
    parser.add_argument(
        "--exist-ok",
        action="store_true",
        help="existing project/name ok, do not increment",
    )
    parser.add_argument("--seed", type=int, default=0, help="Global training seed")
    parser.add_argument("--use-vdl", action="store_true", help="use VisualDL as logger")
    args = parser.parse_args()
    main(args)

View File

@@ -0,0 +1,10 @@
# Enable OpenEXR support in OpenCV so the .exr world-coordinate maps load.
export OPENCV_IO_ENABLE_OPENEXR=1
# glog flag consumed by paddle; =0 presumably routes logs to files — TODO confirm.
export FLAGS_logtostderr=0
# Train on the first GPU only.
export CUDA_VISIBLE_DEVICES=0
python train.py --img-size 288 \
--name "DocTr++" \
--batch-size 12 \
--lr 1e-4 \
--exist-ok \
--use-vdl

View File

@@ -0,0 +1,67 @@
import numpy as np
import paddle
from paddle.nn import functional as F
def to_tensor(img: np.ndarray):
    """
    Converts a numpy array image (HWC) to a Paddle tensor (NCHW).

    Flips BGR to RGB, scales to [0, 1] float32, moves channels first and
    prepends a batch dimension.

    Args:
        img (numpy.ndarray): The input image as a numpy array.
    Returns:
        out (paddle.Tensor): The output tensor.
    """
    rgb = img[:, :, ::-1]  # BGR -> RGB
    scaled = rgb.astype("float32") / 255.0  # [0, 255] -> [0, 1]
    chw = scaled.transpose(2, 0, 1)  # HWC -> CHW
    tensor: paddle.Tensor = paddle.to_tensor(chw)
    return paddle.unsqueeze(tensor, axis=0)  # add batch dim -> NCHW
def to_image(x: paddle.Tensor):
    """
    Converts a Paddle tensor (NCHW) to a numpy array image (HWC).

    Inverse of ``to_tensor``: drops singleton dims, moves channels last,
    rescales to uint8 and flips RGB back to BGR.

    Args:
        x (paddle.Tensor): The input tensor.
    Returns:
        out (numpy.ndarray): The output image as a numpy array.
    """
    chw: np.ndarray = x.squeeze().numpy()
    hwc = chw.transpose(1, 2, 0)  # CHW -> HWC
    scaled = (hwc * 255.0).astype("uint8")  # [0, 1] -> [0, 255]
    return scaled[:, :, ::-1]  # RGB -> BGR
def unwarp(img, bm, bm_data_format="NCHW"):
    """
    Unwarp an image using a flow field.

    Args:
        img (paddle.Tensor): The input image, NCHW.
        bm (paddle.Tensor): The flow field (backward map); layout given by
            ``bm_data_format`` ("NHWC" is converted, anything else is
            treated as NCHW).
    Returns:
        out (paddle.Tensor): The output image.

    NOTE(review): ``F.grid_sample`` samples with coordinates in [-1, 1];
    ``bm`` is therefore assumed to be pre-normalized by the caller — confirm.
    """
    _, _, h, w = img.shape
    if bm_data_format == "NHWC":
        # Move flow channels first so interpolate resizes spatial dims.
        bm = bm.transpose([0, 3, 1, 2])
    # NCHW: resize the flow field to the image resolution
    bm = F.upsample(bm, size=(h, w), mode="bilinear", align_corners=True)
    # NHWC: grid_sample consumes the map with coordinates last
    bm = bm.transpose([0, 2, 3, 1])
    # Sample img at the map's coordinates; output is NCHW
    out = F.grid_sample(img, bm)
    return out

View File

@@ -0,0 +1,152 @@
import math
import numpy as np
import paddle
import paddle.nn.initializer as init
from scipy import special
def weight_init_(
    layer, func, weight_name=None, bias_name=None, bias_value=0.0, **kwargs
):
    """
    In-place params init function.

    Looks up ``func`` by name on ``paddle.nn.initializer``, instantiates it
    with ``**kwargs`` and applies it to ``layer.weight``; the bias, when
    present, is filled with the constant ``bias_value``.

    Args:
        layer: object exposing ``weight`` / ``bias`` attributes.
        func (str): initializer class name, e.g. "Normal" or "Constant".
        weight_name / bias_name: optional replacement parameter names.
        bias_value (float): constant used to fill the bias.

    NOTE(review): callers in extractor.py pass "kaiming_normal_" — confirm
    ``paddle.nn.initializer`` actually exposes that name.

    Usage:
    .. code-block:: python

        import paddle
        import numpy as np

        data = np.ones([3, 4], dtype='float32')
        linear = paddle.nn.Linear(4, 4)
        input = paddle.to_tensor(data)
        print(linear.weight)
        linear(input)

        weight_init_(linear, 'Normal', 'fc_w0', 'fc_b0', std=0.01, mean=0.1)
        print(linear.weight)
    """
    if hasattr(layer, "weight") and layer.weight is not None:
        getattr(init, func)(**kwargs)(layer.weight)
        if weight_name is not None:
            # override weight name
            layer.weight.name = weight_name
    if hasattr(layer, "bias") and layer.bias is not None:
        init.Constant(bias_value)(layer.bias)
        if bias_name is not None:
            # override bias name
            layer.bias.name = bias_name
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    """Fill ``tensor`` in-place from a normal distribution truncated to [a, b].

    Samples uniformly in CDF space and maps back through the inverse normal
    CDF, so no rejection sampling is needed.  Returns ``tensor``.
    """
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        print(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect."
        )
    with paddle.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)
        # Uniformly fill tensor with values from [l, u], then translate to [2l-1, 2u-1].
        tmp = np.random.uniform(
            2 * lower - 1, 2 * upper - 1, size=list(tensor.shape)
        ).astype(np.float32)
        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tmp = special.erfinv(tmp)
        # Transform to proper mean, std
        tmp *= std * math.sqrt(2.0)
        tmp += mean
        # Clamp to ensure it's in the proper range
        tmp = np.clip(tmp, a, b)
        tensor.set_value(paddle.to_tensor(tmp))
        return tensor
def _calculate_fan_in_and_fan_out(tensor):
dimensions = tensor.dim()
if dimensions < 2:
raise ValueError(
"Fan in and fan out can not be computed for tensor "
"with fewer than 2 dimensions"
)
num_input_fmaps = tensor.shape[1]
num_output_fmaps = tensor.shape[0]
receptive_field_size = 1
if tensor.dim() > 2:
receptive_field_size = tensor[0][0].numel()
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    """Public torch-style wrapper around :func:`_no_grad_trunc_normal_`."""
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def kaiming_normal_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
    """Fill ``tensor`` in-place with Kaiming-normal (He) initialization.

    std = gain / sqrt(fan); ``mode`` picks fan_in or fan_out, ``nonlinearity``
    picks the gain, and ``a`` is the leaky-relu negative slope passed to the
    gain computation.  Returns ``tensor``.
    """
    def _calculate_correct_fan(tensor, mode):
        # Select fan_in or fan_out according to ``mode``.
        mode = mode.lower()
        valid_modes = ["fan_in", "fan_out"]
        if mode not in valid_modes:
            raise ValueError(
                "Mode {} not supported, please use one of {}".format(mode, valid_modes)
            )
        fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
        return fan_in if mode == "fan_in" else fan_out

    def calculate_gain(nonlinearity, param=None):
        # Recommended gain per activation (mirrors torch.nn.init.calculate_gain).
        linear_fns = [
            "linear",
            "conv1d",
            "conv2d",
            "conv3d",
            "conv_transpose1d",
            "conv_transpose2d",
            "conv_transpose3d",
        ]
        if nonlinearity in linear_fns or nonlinearity == "sigmoid":
            return 1
        elif nonlinearity == "tanh":
            return 5.0 / 3
        elif nonlinearity == "relu":
            return math.sqrt(2.0)
        elif nonlinearity == "leaky_relu":
            if param is None:
                negative_slope = 0.01
            elif (
                not isinstance(param, bool)
                and isinstance(param, int)
                or isinstance(param, float)
            ):
                negative_slope = param
            else:
                raise ValueError("negative_slope {} not a valid number".format(param))
            return math.sqrt(2.0 / (1 + negative_slope**2))
        else:
            raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))

    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with paddle.no_grad():
        paddle.nn.initializer.Normal(0, std)(tensor)
        return tensor