diff --git a/dewarp/dewarp.py b/dewarp/dewarp.py deleted file mode 100644 index d8bce43..0000000 --- a/dewarp/dewarp.py +++ /dev/null @@ -1,96 +0,0 @@ -import os - -import cv2 -import numpy as np -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch.autograd import Variable - -from dewarp.models import get_model -from dewarp.utils import convert_state_dict - -DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - - -def unwarp(img, bm): - w, h = img.shape[0], img.shape[1] - bm = bm.transpose(1, 2).transpose(2, 3).detach().cpu().numpy()[0, :, :, :] - bm0 = cv2.blur(bm[:, :, 0], (3, 3)) - bm1 = cv2.blur(bm[:, :, 1], (3, 3)) - bm0 = cv2.resize(bm0, (h, w)) - bm1 = cv2.resize(bm1, (h, w)) - bm = np.stack([bm0, bm1], axis=-1) - bm = np.expand_dims(bm, 0) - bm = torch.from_numpy(bm).double() - - img = img.astype(float) / 255.0 - img = img.transpose((2, 0, 1)) - img = np.expand_dims(img, 0) - img = torch.from_numpy(img).double() - - res = F.grid_sample(input=img, grid=bm, align_corners=True) - res = res[0].numpy().transpose((1, 2, 0)) - - return res - - -def dewarp_image(image): - wc_model_path = "model/dewarp_model/unetnc_doc3d.pkl" - bm_model_path = "model/dewarp_model/dnetccnl_doc3d.pkl" - wc_model_file_name = os.path.split(wc_model_path)[1] - wc_model_name = wc_model_file_name[:wc_model_file_name.find('_')] - - bm_model_file_name = os.path.split(bm_model_path)[1] - bm_model_name = bm_model_file_name[:bm_model_file_name.find('_')] - - wc_n_classes = 3 - bm_n_classes = 2 - - wc_img_size = (256, 256) - bm_img_size = (128, 128) - - # Setup image - imgorg = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - img = cv2.resize(imgorg, wc_img_size) - - img = img[:, :, ::-1] - img = img.astype(float) / 255.0 - img = img.transpose(2, 0, 1) # NHWC -> NCHW - img = np.expand_dims(img, 0) - img = torch.from_numpy(img).float() - - # Predict - htan = nn.Hardtanh(0, 1.0) - wc_model = get_model(wc_model_name, wc_n_classes, in_channels=3) - if DEVICE.type == 'cpu': - wc_state = convert_state_dict(torch.load(wc_model_path, map_location='cpu', weights_only=True)['model_state']) - else: - wc_state = convert_state_dict(torch.load(wc_model_path, weights_only=True)['model_state']) - wc_model.load_state_dict(wc_state) - wc_model.eval() - bm_model = get_model(bm_model_name, bm_n_classes, in_channels=3) - if DEVICE.type == 'cpu': - bm_state = convert_state_dict(torch.load(bm_model_path, map_location='cpu', weights_only=True)['model_state']) - else: - bm_state = convert_state_dict(torch.load(bm_model_path, weights_only=True)['model_state']) - bm_model.load_state_dict(bm_state) - bm_model.eval() - - if torch.cuda.is_available(): - wc_model.cuda() - bm_model.cuda() - images = Variable(img.cuda()) - else: - images = Variable(img) - - with torch.no_grad(): - wc_outputs = wc_model(images) - pred_wc = htan(wc_outputs) - bm_input = F.interpolate(pred_wc, bm_img_size) - outputs_bm = bm_model(bm_input) - - # call unwarp - uwpred = unwarp(imgorg, outputs_bm) - uwpred = (uwpred * 255).astype(np.uint8) - return cv2.cvtColor(uwpred, cv2.COLOR_RGB2BGR) diff --git a/dewarp/models/__init__.py b/dewarp/models/__init__.py deleted file mode 100644 index 50b9225..0000000 --- a/dewarp/models/__init__.py +++ /dev/null @@ -1,24 +0,0 @@ -from dewarp.models.densenetccnl import dnetccnl -from dewarp.models.unetnc import UnetGenerator - - -def get_model(name, n_classes=1, in_channels=3): - model = _get_model_instance(name) - - if name == 'dnetccnl': - model = model(img_size=128, in_channels=in_channels, 
out_channels=n_classes, filters=32) - elif name == 'unetnc': - model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7) - else: - model = model(n_classes=n_classes) - return model - - -def _get_model_instance(name): - try: - return { - 'dnetccnl': dnetccnl, - 'unetnc': UnetGenerator, - }[name] - except: - print('Model {} not available'.format(name)) diff --git a/dewarp/models/densenetccnl.py b/dewarp/models/densenetccnl.py deleted file mode 100644 index d92ed92..0000000 --- a/dewarp/models/densenetccnl.py +++ /dev/null @@ -1,243 +0,0 @@ -# Densenet decoder encoder with intermediate fully connected layers and dropout - -import numpy as np -import torch -import torch.nn as nn - - -def add_coordConv_channels(t): - n, c, h, w = t.size() - xx_channel = np.ones((h, w)) - xx_range = np.array(range(h)) - xx_range = np.expand_dims(xx_range, -1) - xx_coord = xx_channel * xx_range - yy_coord = xx_coord.transpose() - - xx_coord = xx_coord / (h - 1) - yy_coord = yy_coord / (h - 1) - xx_coord = xx_coord * 2 - 1 - yy_coord = yy_coord * 2 - 1 - xx_coord = torch.from_numpy(xx_coord).float() - yy_coord = torch.from_numpy(yy_coord).float() - - if t.is_cuda: - xx_coord = xx_coord.cuda() - yy_coord = yy_coord.cuda() - - xx_coord = xx_coord.unsqueeze(0).unsqueeze(0).repeat(n, 1, 1, 1) - yy_coord = yy_coord.unsqueeze(0).unsqueeze(0).repeat(n, 1, 1, 1) - - t_cc = torch.cat((t, xx_coord, yy_coord), dim=1) - - return t_cc - - -class DenseBlockEncoder(nn.Module): - def __init__(self, n_channels, n_convs, activation=nn.ReLU, args=[False]): - super(DenseBlockEncoder, self).__init__() - assert (n_convs > 0) - - self.n_channels = n_channels - self.n_convs = n_convs - self.layers = nn.ModuleList() - for i in range(n_convs): - self.layers.append(nn.Sequential( - nn.BatchNorm2d(n_channels), - activation(*args), - nn.Conv2d(n_channels, n_channels, 3, stride=1, padding=1, bias=False), )) - - def forward(self, inputs): - outputs = [] - - for i, layer in enumerate(self.layers): - if i > 0: - next_output = 0 - for no in outputs: - next_output = next_output + no - outputs.append(next_output) - else: - outputs.append(layer(inputs)) - return outputs[-1] - - -# Dense block in encoder. 
-class DenseBlockDecoder(nn.Module): - def __init__(self, n_channels, n_convs, activation=nn.ReLU, args=[False]): - super(DenseBlockDecoder, self).__init__() - assert (n_convs > 0) - - self.n_channels = n_channels - self.n_convs = n_convs - self.layers = nn.ModuleList() - for i in range(n_convs): - self.layers.append(nn.Sequential( - nn.BatchNorm2d(n_channels), - activation(*args), - nn.ConvTranspose2d(n_channels, n_channels, 3, stride=1, padding=1, bias=False), )) - - def forward(self, inputs): - outputs = [] - - for i, layer in enumerate(self.layers): - if i > 0: - next_output = 0 - for no in outputs: - next_output = next_output + no - outputs.append(next_output) - else: - outputs.append(layer(inputs)) - return outputs[-1] - - -class DenseTransitionBlockEncoder(nn.Module): - def __init__(self, n_channels_in, n_channels_out, mp, activation=nn.ReLU, args=[False]): - super(DenseTransitionBlockEncoder, self).__init__() - self.n_channels_in = n_channels_in - self.n_channels_out = n_channels_out - self.mp = mp - self.main = nn.Sequential( - nn.BatchNorm2d(n_channels_in), - activation(*args), - nn.Conv2d(n_channels_in, n_channels_out, 1, stride=1, padding=0, bias=False), - nn.MaxPool2d(mp), - ) - - def forward(self, inputs): - return self.main(inputs) - - -class DenseTransitionBlockDecoder(nn.Module): - def __init__(self, n_channels_in, n_channels_out, activation=nn.ReLU, args=[False]): - super(DenseTransitionBlockDecoder, self).__init__() - self.n_channels_in = n_channels_in - self.n_channels_out = n_channels_out - self.main = nn.Sequential( - nn.BatchNorm2d(n_channels_in), - activation(*args), - nn.ConvTranspose2d(n_channels_in, n_channels_out, 4, stride=2, padding=1, bias=False), - ) - - def forward(self, inputs): - return self.main(inputs) - - -## Dense encoders and decoders for image of size 128 128 -class waspDenseEncoder128(nn.Module): - def __init__(self, nc=1, ndf=32, ndim=128, activation=nn.LeakyReLU, args=[0.2, False], f_activation=nn.Tanh, - f_args=[]): - super(waspDenseEncoder128, self).__init__() - self.ndim = ndim - - self.main = nn.Sequential( - # input is (nc) x 128 x 128 - nn.BatchNorm2d(nc), - nn.ReLU(True), - nn.Conv2d(nc, ndf, 4, stride=2, padding=1), - - # state size. (ndf) x 64 x 64 - DenseBlockEncoder(ndf, 6), - DenseTransitionBlockEncoder(ndf, ndf * 2, 2, activation=activation, args=args), - - # state size. (ndf*2) x 32 x 32 - DenseBlockEncoder(ndf * 2, 12), - DenseTransitionBlockEncoder(ndf * 2, ndf * 4, 2, activation=activation, args=args), - - # state size. (ndf*4) x 16 x 16 - DenseBlockEncoder(ndf * 4, 16), - DenseTransitionBlockEncoder(ndf * 4, ndf * 8, 2, activation=activation, args=args), - - # state size. (ndf*4) x 8 x 8 - DenseBlockEncoder(ndf * 8, 16), - DenseTransitionBlockEncoder(ndf * 8, ndf * 8, 2, activation=activation, args=args), - - # state size. 
(ndf*8) x 4 x 4 - DenseBlockEncoder(ndf * 8, 16), - DenseTransitionBlockEncoder(ndf * 8, ndim, 4, activation=activation, args=args), - f_activation(*f_args), - ) - - def forward(self, input): - input = add_coordConv_channels(input) - output = self.main(input).view(-1, self.ndim) - # print(output.size()) - return output - - -class waspDenseDecoder128(nn.Module): - def __init__(self, nz=128, nc=1, ngf=32, lb=0, ub=1, activation=nn.ReLU, args=[False], f_activation=nn.Hardtanh, - f_args=[]): - super(waspDenseDecoder128, self).__init__() - self.main = nn.Sequential( - # input is Z, going into convolution - nn.BatchNorm2d(nz), - activation(*args), - nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), - - # state size. (ngf*8) x 4 x 4 - DenseBlockDecoder(ngf * 8, 16), - DenseTransitionBlockDecoder(ngf * 8, ngf * 8), - - # state size. (ngf*4) x 8 x 8 - DenseBlockDecoder(ngf * 8, 16), - DenseTransitionBlockDecoder(ngf * 8, ngf * 4), - - # state size. (ngf*2) x 16 x 16 - DenseBlockDecoder(ngf * 4, 12), - DenseTransitionBlockDecoder(ngf * 4, ngf * 2), - - # state size. (ngf) x 32 x 32 - DenseBlockDecoder(ngf * 2, 6), - DenseTransitionBlockDecoder(ngf * 2, ngf), - - # state size. (ngf) x 64 x 64 - DenseBlockDecoder(ngf, 6), - DenseTransitionBlockDecoder(ngf, ngf), - - # state size (ngf) x 128 x 128 - nn.BatchNorm2d(ngf), - activation(*args), - nn.ConvTranspose2d(ngf, nc, 3, stride=1, padding=1, bias=False), - f_activation(*f_args), - ) - # self.smooth=nn.Sequential( - # nn.Conv2d(nc, nc, 1, stride=1, padding=0, bias=False), - # f_activation(*f_args), - # ) - - def forward(self, inputs): - # return self.smooth(self.main(inputs)) - return self.main(inputs) - - -class dnetccnl(nn.Module): - # in_channels -> nc | encoder first layer - # filters -> ndf | encoder first layer - # img_size(h,w) -> ndim - # out_channels -> optical flow (x,y) - - def __init__(self, img_size=128, in_channels=1, out_channels=2, filters=32, fc_units=100): - super(dnetccnl, self).__init__() - self.nc = in_channels - self.nf = filters - self.ndim = img_size - self.oc = out_channels - self.fcu = fc_units - - self.encoder = waspDenseEncoder128(nc=self.nc + 2, ndf=self.nf, ndim=self.ndim) - self.decoder = waspDenseDecoder128(nz=self.ndim, nc=self.oc, ngf=self.nf) - # self.fc_layers= nn.Sequential(nn.Linear(self.ndim, self.fcu), - # nn.ReLU(True), - # nn.Dropout(0.25), - # nn.Linear(self.fcu,self.ndim), - # nn.ReLU(True), - # nn.Dropout(0.25), - # ) - - def forward(self, inputs): - encoded = self.encoder(inputs) - encoded = encoded.unsqueeze(-1).unsqueeze(-1) - decoded = self.decoder(encoded) - # print torch.max(decoded) - # print torch.min(decoded) - - return decoded diff --git a/dewarp/models/unetnc.py b/dewarp/models/unetnc.py deleted file mode 100644 index b6c38dc..0000000 --- a/dewarp/models/unetnc.py +++ /dev/null @@ -1,89 +0,0 @@ -import functools - -import torch -import torch.nn as nn - - -# Defines the Unet generator. -# |num_downs|: number of downsamplings in UNet. 
For example, -# if |num_downs| == 7, image of size 128x128 will become of size 1x1 -# at the bottleneck -class UnetGenerator(nn.Module): - def __init__(self, input_nc, output_nc, num_downs, ngf=64, - norm_layer=nn.BatchNorm2d, use_dropout=False): - super(UnetGenerator, self).__init__() - - # construct unet structure - unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, - innermost=True) - for i in range(num_downs - 5): - unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, - norm_layer=norm_layer, use_dropout=use_dropout) - unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, - norm_layer=norm_layer) - unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, - norm_layer=norm_layer) - unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) - unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, - norm_layer=norm_layer) - - self.model = unet_block - - def forward(self, input): - return self.model(input) - - -# Defines the submodule with skip connection. -# X -------------------identity---------------------- X -# |-- downsampling -- |submodule| -- upsampling --| -class UnetSkipConnectionBlock(nn.Module): - def __init__(self, outer_nc, inner_nc, input_nc=None, - submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): - super(UnetSkipConnectionBlock, self).__init__() - self.outermost = outermost - if type(norm_layer) == functools.partial: - use_bias = norm_layer.func == nn.InstanceNorm2d - else: - use_bias = norm_layer == nn.InstanceNorm2d - if input_nc is None: - input_nc = outer_nc - downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, - stride=2, padding=1, bias=use_bias) - downrelu = nn.LeakyReLU(0.2, True) - downnorm = norm_layer(inner_nc) - uprelu = nn.ReLU(True) - upnorm = norm_layer(outer_nc) - - if outermost: - upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, - kernel_size=4, stride=2, - padding=1) - down = [downconv] - up = [uprelu, upconv, nn.Tanh()] - model = down + [submodule] + up - elif innermost: - upconv = nn.ConvTranspose2d(inner_nc, outer_nc, - kernel_size=4, stride=2, - padding=1, bias=use_bias) - down = [downrelu, downconv] - up = [uprelu, upconv, upnorm] - model = down + up - else: - upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, - kernel_size=4, stride=2, - padding=1, bias=use_bias) - down = [downrelu, downconv, downnorm] - up = [uprelu, upconv, upnorm] - - if use_dropout: - model = down + [submodule] + up + [nn.Dropout(0.5)] - else: - model = down + [submodule] + up - - self.model = nn.Sequential(*model) - - def forward(self, x): - if self.outermost: - return self.model(x) - else: - return torch.cat([x, self.model(x)], 1) diff --git a/dewarp/utils.py b/dewarp/utils.py deleted file mode 100644 index 65c0c4b..0000000 --- a/dewarp/utils.py +++ /dev/null @@ -1,17 +0,0 @@ -''' -Misc Utility functions -''' -from collections import OrderedDict - - -def convert_state_dict(state_dict): - """Converts a state dict saved from a dataParallel module to normal - module state_dict inplace - :param state_dict is the loaded DataParallel model_state - - """ - new_state_dict = OrderedDict() - for k, v in state_dict.items(): - name = k[7:] # remove `module.` - new_state_dict[name] = v - return new_state_dict diff --git a/doc_dewarp/.gitignore 
b/doc_dewarp/.gitignore new file mode 100644 index 0000000..7179701 --- /dev/null +++ b/doc_dewarp/.gitignore @@ -0,0 +1,167 @@ +input/ +output/ + +runs/ +*_dump/ +*_log/ + +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+.idea/ diff --git a/doc_dewarp/.pre-commit-config.yaml b/doc_dewarp/.pre-commit-config.yaml new file mode 100644 index 0000000..07921f0 --- /dev/null +++ b/doc_dewarp/.pre-commit-config.yaml @@ -0,0 +1,34 @@ +repos: +# Common hooks +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.4.0 + hooks: + - id: check-added-large-files + - id: check-merge-conflict + - id: check-symlinks + - id: detect-private-key + - id: end-of-file-fixer + - id: trailing-whitespace +- repo: https://github.com/Lucas-C/pre-commit-hooks.git + rev: v1.5.1 + hooks: + - id: remove-crlf + - id: remove-tabs + name: Tabs remover (Python) + files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ + args: [--whitespaces-count, '4'] +# For Python files +- repo: https://github.com/psf/black.git + rev: 23.3.0 + hooks: + - id: black + files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$ +- repo: https://github.com/pycqa/isort + rev: 5.11.5 + hooks: + - id: isort +- repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.0.272 + hooks: + - id: ruff + args: [--fix, --exit-non-zero-on-fix, --no-cache] diff --git a/doc_dewarp/GeoTr.py b/doc_dewarp/GeoTr.py new file mode 100644 index 0000000..44cd01a --- /dev/null +++ b/doc_dewarp/GeoTr.py @@ -0,0 +1,398 @@ +import copy +from typing import Optional + +import paddle +import paddle.nn as nn +import paddle.nn.functional as F + +from .extractor import BasicEncoder +from .position_encoding import build_position_encoding +from .weight_init import weight_init_ + + +class attnLayer(nn.Layer): + def __init__( + self, + d_model, + nhead=8, + dim_feedforward=2048, + dropout=0.1, + activation="relu", + normalize_before=False, + ): + super().__init__() + + self.self_attn = nn.MultiHeadAttention(d_model, nhead, dropout=dropout) + self.multihead_attn_list = nn.LayerList( + [ + copy.deepcopy(nn.MultiHeadAttention(d_model, nhead, dropout=dropout)) + for i in range(2) + ] + ) + + self.linear1 = nn.Linear(d_model, dim_feedforward) + self.dropout = nn.Dropout(p=dropout) + self.linear2 = nn.Linear(dim_feedforward, d_model) + + self.norm1 = nn.LayerNorm(d_model) + self.norm2_list = nn.LayerList( + [copy.deepcopy(nn.LayerNorm(d_model)) for i in range(2)] + ) + + self.norm3 = nn.LayerNorm(d_model) + self.dropout1 = nn.Dropout(p=dropout) + self.dropout2_list = nn.LayerList( + [copy.deepcopy(nn.Dropout(p=dropout)) for i in range(2)] + ) + self.dropout3 = nn.Dropout(p=dropout) + + self.activation = _get_activation_fn(activation) + self.normalize_before = normalize_before + + def with_pos_embed(self, tensor, pos: Optional[paddle.Tensor]): + return tensor if pos is None else tensor + pos + + def forward_post( + self, + tgt, + memory_list, + tgt_mask=None, + memory_mask=None, + pos=None, + memory_pos=None, + ): + q = k = self.with_pos_embed(tgt, pos) + tgt2 = self.self_attn( + q.transpose((1, 0, 2)), + k.transpose((1, 0, 2)), + value=tgt.transpose((1, 0, 2)), + attn_mask=tgt_mask, + ) + tgt2 = tgt2.transpose((1, 0, 2)) + tgt = tgt + self.dropout1(tgt2) + tgt = self.norm1(tgt) + + for memory, multihead_attn, norm2, dropout2, m_pos in zip( + memory_list, + self.multihead_attn_list, + self.norm2_list, + self.dropout2_list, + memory_pos, + ): + tgt2 = multihead_attn( + query=self.with_pos_embed(tgt, pos).transpose((1, 0, 2)), + key=self.with_pos_embed(memory, m_pos).transpose((1, 0, 2)), + value=memory.transpose((1, 0, 2)), + attn_mask=memory_mask, + ).transpose((1, 0, 2)) + + tgt = tgt + dropout2(tgt2) + tgt = norm2(tgt) + + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt)))) + tgt = 
tgt + self.dropout3(tgt2)
+        tgt = self.norm3(tgt)
+
+        return tgt
+
+    def forward_pre(
+        self,
+        tgt,
+        memory_list,
+        tgt_mask=None,
+        memory_mask=None,
+        pos=None,
+        memory_pos=None,
+    ):
+        # Pre-norm variant of forward_post: each sublayer normalizes its input
+        # first and reuses the same attention/norm/dropout lists.
+        tgt2 = self.norm1(tgt)
+
+        q = k = self.with_pos_embed(tgt2, pos)
+        tgt2 = self.self_attn(
+            q.transpose((1, 0, 2)),
+            k.transpose((1, 0, 2)),
+            value=tgt2.transpose((1, 0, 2)),
+            attn_mask=tgt_mask,
+        ).transpose((1, 0, 2))
+        tgt = tgt + self.dropout1(tgt2)
+
+        for memory, multihead_attn, norm2, dropout2, m_pos in zip(
+            memory_list,
+            self.multihead_attn_list,
+            self.norm2_list,
+            self.dropout2_list,
+            memory_pos,
+        ):
+            tgt2 = norm2(tgt)
+            tgt2 = multihead_attn(
+                query=self.with_pos_embed(tgt2, pos).transpose((1, 0, 2)),
+                key=self.with_pos_embed(memory, m_pos).transpose((1, 0, 2)),
+                value=memory.transpose((1, 0, 2)),
+                attn_mask=memory_mask,
+            ).transpose((1, 0, 2))
+            tgt = tgt + dropout2(tgt2)
+
+        tgt2 = self.norm3(tgt)
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
+        tgt = tgt + self.dropout3(tgt2)
+
+        return tgt
+
+    def forward(
+        self,
+        tgt,
+        memory_list,
+        tgt_mask=None,
+        memory_mask=None,
+        pos=None,
+        memory_pos=None,
+    ):
+        if self.normalize_before:
+            return self.forward_pre(
+                tgt,
+                memory_list,
+                tgt_mask,
+                memory_mask,
+                pos,
+                memory_pos,
+            )
+        return self.forward_post(
+            tgt,
+            memory_list,
+            tgt_mask,
+            memory_mask,
+            pos,
+            memory_pos,
+        )
+
+
+def _get_clones(module, N):
+    return nn.LayerList([copy.deepcopy(module) for i in range(N)])
+
+
+def _get_activation_fn(activation):
+    """Return an activation function given a string"""
+    if activation == "relu":
+        return F.relu
+    if activation == "gelu":
+        return F.gelu
+    if activation == "glu":
+        return F.glu
+    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
+
+
+class TransDecoder(nn.Layer):
+    def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
+        super(TransDecoder, self).__init__()
+
+        attn_layer = attnLayer(hidden_dim)
+        self.layers = _get_clones(attn_layer, num_attn_layers)
+        self.position_embedding = build_position_encoding(hidden_dim)
+
+    def forward(self, image: paddle.Tensor, query_embed: paddle.Tensor):
+        pos = self.position_embedding(
+            paddle.ones([image.shape[0], image.shape[2], image.shape[3]], dtype="bool")
+        )
+
+        b, c, h, w = image.shape
+
+        image = image.flatten(2).transpose(perm=[2, 0, 1])
+        pos = pos.flatten(2).transpose(perm=[2, 0, 1])
+
+        for layer in self.layers:
+            query_embed = layer(query_embed, [image], pos=pos, memory_pos=[pos, pos])
+
+        query_embed = query_embed.transpose(perm=[1, 2, 0]).reshape([b, c, h, w])
+
+        return query_embed
+
+
+class TransEncoder(nn.Layer):
+    def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
+        super(TransEncoder, self).__init__()
+
+        attn_layer = attnLayer(hidden_dim)
+        self.layers = _get_clones(attn_layer, num_attn_layers)
+        self.position_embedding = build_position_encoding(hidden_dim)
+
+    def forward(self, image: paddle.Tensor):
+        pos = self.position_embedding(
+            paddle.ones([image.shape[0], image.shape[2], image.shape[3]], dtype="bool")
+        )
+
+        b, c, h, w = image.shape
+
+        image = image.flatten(2).transpose(perm=[2, 0, 1])
+        pos = pos.flatten(2).transpose(perm=[2, 0, 1])
+
+        for layer in self.layers:
+            image = layer(image, [image], pos=pos, memory_pos=[pos, pos])
+
+        image = image.transpose(perm=[1, 2, 0]).reshape([b, c, h, w])
+
+        return image
+
+
+class FlowHead(nn.Layer):
+    def __init__(self, input_dim=128, hidden_dim=256):
+        super(FlowHead, self).__init__()
+
+        self.conv1 = nn.Conv2D(input_dim, hidden_dim, 3, padding=1)
+        self.conv2 = nn.Conv2D(hidden_dim, 2, 3, padding=1)
+        self.relu = nn.ReLU()
+
+    def forward(self, x):
+        return self.conv2(self.relu(self.conv1(x)))
+
+
+class UpdateBlock(nn.Layer):
+    def __init__(self, hidden_dim: int = 128):
+        super(UpdateBlock, self).__init__()
+
+        self.flow_head = 
FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2D(hidden_dim, 256, 3, padding=1), + nn.ReLU(), + nn.Conv2D(256, 64 * 9, 1, padding=0), + ) + + def forward(self, image, coords): + mask = 0.25 * self.mask(image) + dflow = self.flow_head(image) + coords = coords + dflow + return mask, coords + + +def coords_grid(batch, ht, wd): + coords = paddle.meshgrid(paddle.arange(end=ht), paddle.arange(end=wd)) + coords = paddle.stack(coords[::-1], axis=0).astype(dtype="float32") + return coords[None].tile([batch, 1, 1, 1]) + + +def upflow8(flow, mode="bilinear"): + new_size = 8 * flow.shape[2], 8 * flow.shape[3] + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) + + +class OverlapPatchEmbed(nn.Layer): + """Image to Patch Embedding""" + + def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768): + super().__init__() + + img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size) + patch_size = ( + patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size) + ) + + self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1] + self.num_patches = self.H * self.W + + self.proj = nn.Conv2D( + in_chans, + embed_dim, + patch_size, + stride=stride, + padding=(patch_size[0] // 2, patch_size[1] // 2), + ) + self.norm = nn.LayerNorm(embed_dim) + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + weight_init_(m, "trunc_normal_", std=0.02) + elif isinstance(m, nn.LayerNorm): + weight_init_(m, "Constant", value=1.0) + elif isinstance(m, nn.Conv2D): + weight_init_( + m.weight, "kaiming_normal_", mode="fan_out", nonlinearity="relu" + ) + + def forward(self, x): + x = self.proj(x) + _, _, H, W = x.shape + x = x.flatten(2) + + perm = list(range(x.ndim)) + perm[1] = 2 + perm[2] = 1 + x = x.transpose(perm=perm) + + x = self.norm(x) + + return x, H, W + + +class GeoTr(nn.Layer): + def __init__(self): + super(GeoTr, self).__init__() + + self.hidden_dim = hdim = 256 + + self.fnet = BasicEncoder(output_dim=hdim, norm_fn="instance") + + self.encoder_block = [("encoder_block" + str(i)) for i in range(3)] + for i in self.encoder_block: + self.__setattr__(i, TransEncoder(2, hidden_dim=hdim)) + + self.down_layer = [("down_layer" + str(i)) for i in range(2)] + for i in self.down_layer: + self.__setattr__(i, nn.Conv2D(256, 256, 3, stride=2, padding=1)) + + self.decoder_block = [("decoder_block" + str(i)) for i in range(3)] + for i in self.decoder_block: + self.__setattr__(i, TransDecoder(2, hidden_dim=hdim)) + + self.up_layer = [("up_layer" + str(i)) for i in range(2)] + for i in self.up_layer: + self.__setattr__( + i, nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) + ) + + self.query_embed = nn.Embedding(81, self.hidden_dim) + + self.update_block = UpdateBlock(self.hidden_dim) + + def initialize_flow(self, img): + N, _, H, W = img.shape + coodslar = coords_grid(N, H, W) + coords0 = coords_grid(N, H // 8, W // 8) + coords1 = coords_grid(N, H // 8, W // 8) + return coodslar, coords0, coords1 + + def upsample_flow(self, flow, mask): + N, _, H, W = flow.shape + + mask = mask.reshape([N, 1, 9, 8, 8, H, W]) + mask = F.softmax(mask, axis=2) + + up_flow = F.unfold(8 * flow, [3, 3], paddings=1) + up_flow = up_flow.reshape([N, 2, 9, 1, 1, H, W]) + + up_flow = paddle.sum(mask * up_flow, axis=2) + up_flow = up_flow.transpose(perm=[0, 1, 4, 2, 5, 3]) + + return up_flow.reshape([N, 2, 8 * H, 8 * W]) + + def forward(self, image): + fmap = self.fnet(image) 
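+        # fnet is a stride-8 convolutional backbone; the blocks below refine
+        # its features with a three-level transformer encoder/decoder pyramid.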
+        fmap = F.relu(fmap)
+
+        fmap1 = self.__getattr__(self.encoder_block[0])(fmap)
+        fmap1d = self.__getattr__(self.down_layer[0])(fmap1)
+        fmap2 = self.__getattr__(self.encoder_block[1])(fmap1d)
+        fmap2d = self.__getattr__(self.down_layer[1])(fmap2)
+        fmap3 = self.__getattr__(self.encoder_block[2])(fmap2d)
+
+        query_embed0 = self.query_embed.weight.unsqueeze(1).tile([1, fmap3.shape[0], 1])
+
+        fmap3d_ = self.__getattr__(self.decoder_block[0])(fmap3, query_embed0)
+        fmap3du_ = (
+            self.__getattr__(self.up_layer[0])(fmap3d_)
+            .flatten(2)
+            .transpose(perm=[2, 0, 1])
+        )
+        fmap2d_ = self.__getattr__(self.decoder_block[1])(fmap2, fmap3du_)
+        fmap2du_ = (
+            self.__getattr__(self.up_layer[1])(fmap2d_)
+            .flatten(2)
+            .transpose(perm=[2, 0, 1])
+        )
+        fmap_out = self.__getattr__(self.decoder_block[2])(fmap1, fmap2du_)
+
+        coodslar, coords0, coords1 = self.initialize_flow(image)
+        coords1 = coords1.detach()
+        mask, coords1 = self.update_block(fmap_out, coords1)
+        flow_up = self.upsample_flow(coords1 - coords0, mask)
+        bm_up = coodslar + flow_up
+
+        return bm_up
diff --git a/doc_dewarp/README.md b/doc_dewarp/README.md
new file mode 100644
index 0000000..60127c2
--- /dev/null
+++ b/doc_dewarp/README.md
@@ -0,0 +1,73 @@
+# DocTrPP: DocTr++ in PaddlePaddle
+
+## Introduction
+
+This is a PaddlePaddle implementation of DocTr++. The original paper is [DocTr++: Deep Unrestricted Document Image Rectification](https://arxiv.org/abs/2304.08796); the original code is [here](https://github.com/fh2019ustc/DocTr-Plus).
+
+![demo](https://github.com/GreatV/DocTrPP/assets/17264618/4e491512-bfc4-4e69-a833-fd1c6e17158c)
+
+## Requirements
+
+Install the latest version of PaddlePaddle by following the official instructions at this [link](https://www.paddlepaddle.org.cn/).
+
+## Training
+
+1. Data Preparation
+
+To prepare the dataset, refer to [doc3D](https://github.com/cvlab-stonybrook/doc3D-dataset).
+
+2. Training
+
+```shell
+sh train.sh
+```
+
+or
+
+```shell
+export OPENCV_IO_ENABLE_OPENEXR=1
+export CUDA_VISIBLE_DEVICES=0
+
+python train.py --img-size 288 \
+    --name "DocTr++" \
+    --batch-size 12 \
+    --lr 2.5e-5 \
+    --exist-ok \
+    --use-vdl
+```
+
+3. Load a Trained Model and Continue Training
+
+```shell
+export OPENCV_IO_ENABLE_OPENEXR=1
+export CUDA_VISIBLE_DEVICES=0
+
+python train.py --img-size 288 \
+    --name "DocTr++" \
+    --batch-size 12 \
+    --lr 2.5e-5 \
+    --resume "runs/train/DocTr++/weights/last.ckpt" \
+    --exist-ok \
+    --use-vdl
+```
+
+## Test and Inference
+
+Test the dewarping result on a single image:
+
+```shell
+python predict.py -i "crop/12_2 copy.png" -m runs/train/DocTr++/weights/best.ckpt -o 12.2.png
+```
+
+![document image rectification](https://raw.githubusercontent.com/greatv/DocTrPP/main/doc/imgs/document_image_rectification.jpg)
+
+## Export to ONNX
+
+```shell
+pip install paddle2onnx

+python export.py -m ./best.ckpt --format onnx
+```
+
+## Model Download
+
+The trained model can be downloaded from [here](https://github.com/GreatV/DocTrPP/releases/download/v0.0.2/best.ckpt).
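+
+## Inference from Python
+
+For reference, the following minimal sketch mirrors what `predict.py` in this repository does; the checkpoint path and input/output filenames are placeholders, and the `to_tensor`/`to_image` helpers come from this repo's `utils.py`.
+
+```python
+import cv2
+import paddle
+
+from GeoTr import GeoTr
+from utils import to_image, to_tensor
+
+# Load the trained rectification model
+model = GeoTr()
+model.set_state_dict(paddle.load("best.ckpt")["model"])
+model.eval()
+
+img_org = cv2.imread("input.png")
+x = to_tensor(cv2.resize(img_org, (288, 288)))  # network input is 288x288
+y = to_tensor(img_org)
+
+bm = model(x)  # backward map in 288x288 pixel coordinates
+bm = paddle.nn.functional.interpolate(
+    bm, y.shape[2:], mode="bilinear", align_corners=False
+)
+# grid_sample expects NHWC coordinates normalized to [-1, 1]
+out = paddle.nn.functional.grid_sample(y, (bm.transpose([0, 2, 3, 1]) / 288 - 0.5) * 2)
+cv2.imwrite("output.png", to_image(out))
+```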
diff --git a/dewarp/__init__.py b/doc_dewarp/__init__.py similarity index 100% rename from dewarp/__init__.py rename to doc_dewarp/__init__.py diff --git a/doc_dewarp/data_visualization.py b/doc_dewarp/data_visualization.py new file mode 100644 index 0000000..ad716ad --- /dev/null +++ b/doc_dewarp/data_visualization.py @@ -0,0 +1,133 @@ +import argparse +import os +import random + +import cv2 +import hdf5storage as h5 +import matplotlib.pyplot as plt +import numpy as np +import paddle +import paddle.nn.functional as F + +os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1" + +parser = argparse.ArgumentParser() +parser.add_argument( + "--data_root", + nargs="?", + type=str, + default="~/datasets/doc3d/", + help="Path to the downloaded dataset", +) +parser.add_argument( + "--folder", nargs="?", type=int, default=1, help="Folder ID to read from" +) +parser.add_argument( + "--output", + nargs="?", + type=str, + default="output.png", + help="Output filename for the image", +) + +args = parser.parse_args() + +root = os.path.expanduser(args.data_root) +folder = args.folder +dirname = os.path.join(root, "img", str(folder)) + +choices = [f for f in os.listdir(dirname) if "png" in f] +fname = random.choice(choices) + +# Read Image +img_path = os.path.join(dirname, fname) +img = cv2.imread(img_path) + +# Read 3D Coords +wc_path = os.path.join(root, "wc", str(folder), fname[:-3] + "exr") +wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) +# scale wc +# value obtained from the entire dataset +xmx, xmn, ymx, ymn, zmx, zmn = ( + 1.2539363, + -1.2442188, + 1.2396319, + -1.2289206, + 0.6436657, + -0.67492497, +) +wc[:, :, 0] = (wc[:, :, 0] - zmn) / (zmx - zmn) +wc[:, :, 1] = (wc[:, :, 1] - ymn) / (ymx - ymn) +wc[:, :, 2] = (wc[:, :, 2] - xmn) / (xmx - xmn) + +# Read Backward Map +bm_path = os.path.join(root, "bm", str(folder), fname[:-3] + "mat") +bm = h5.loadmat(bm_path)["bm"] + +# Read UV Map +uv_path = os.path.join(root, "uv", str(folder), fname[:-3] + "exr") +uv = cv2.imread(uv_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) + +# Read Depth Map +dmap_path = os.path.join(root, "dmap", str(folder), fname[:-3] + "exr") +dmap = cv2.imread(dmap_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)[:, :, 0] +# do some clipping and scaling to display it +dmap[dmap > 30.0] = 30 +dmap = 1 - ((dmap - np.min(dmap)) / (np.max(dmap) - np.min(dmap))) + +# Read Normal Map +norm_path = os.path.join(root, "norm", str(folder), fname[:-3] + "exr") +norm = cv2.imread(norm_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) + +# Read Albedo +alb_path = os.path.join(root, "alb", str(folder), fname[:-3] + "png") +alb = cv2.imread(alb_path) + +# Read Checkerboard Image +recon_path = os.path.join(root, "recon", str(folder), fname[:-8] + "chess480001.png") +recon = cv2.imread(recon_path) + +# Display image and GTs + +# use the backward mapping to dewarp the image +# scale bm to -1.0 to 1.0 +bm_ = bm / np.array([448, 448]) +bm_ = (bm_ - 0.5) * 2 +bm_ = np.reshape(bm_, (1, 448, 448, 2)) +bm_ = paddle.to_tensor(bm_, dtype="float32") +img_ = alb.transpose((2, 0, 1)).astype(np.float32) / 255.0 +img_ = np.expand_dims(img_, 0) +img_ = paddle.to_tensor(img_, dtype="float32") +uw = F.grid_sample(img_, bm_) +uw = uw[0].numpy().transpose((1, 2, 0)) + +f, axrr = plt.subplots(2, 5) +for ax in axrr: + for a in ax: + a.set_xticks([]) + a.set_yticks([]) + +axrr[0][0].imshow(img) +axrr[0][0].title.set_text("image") +axrr[0][1].imshow(wc) +axrr[0][1].title.set_text("3D coords") +axrr[0][2].imshow(bm[:, :, 0]) +axrr[0][2].title.set_text("bm 
0") +axrr[0][3].imshow(bm[:, :, 1]) +axrr[0][3].title.set_text("bm 1") +if uv is None: + uv = np.zeros_like(img) +axrr[0][4].imshow(uv) +axrr[0][4].title.set_text("uv map") +axrr[1][0].imshow(dmap) +axrr[1][0].title.set_text("depth map") +axrr[1][1].imshow(norm) +axrr[1][1].title.set_text("normal map") +axrr[1][2].imshow(alb) +axrr[1][2].title.set_text("albedo") +axrr[1][3].imshow(recon) +axrr[1][3].title.set_text("checkerboard") +axrr[1][4].imshow(uw) +axrr[1][4].title.set_text("gt unwarped") +plt.tight_layout() +plt.savefig(args.output) diff --git a/doc_dewarp/dewarp.py b/doc_dewarp/dewarp.py new file mode 100644 index 0000000..1decad9 --- /dev/null +++ b/doc_dewarp/dewarp.py @@ -0,0 +1,26 @@ +import cv2 +import paddle + +from .GeoTr import GeoTr +from .utils import to_tensor, to_image + + +def dewarp_image(image): + model_path = "model/dewarp_model/best.ckpt" + + checkpoint = paddle.load(model_path) + state_dict = checkpoint["model"] + model = GeoTr() + model.set_state_dict(state_dict) + model.eval() + + img = cv2.resize(image, (288, 288)) + x = to_tensor(img) + y = to_tensor(image) + bm = model(x) + bm = paddle.nn.functional.interpolate( + bm, y.shape[2:], mode="bilinear", align_corners=False + ) + bm_nhwc = bm.transpose([0, 2, 3, 1]) + out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2) + return to_image(out) diff --git a/doc_dewarp/doc/download_dataset.sh b/doc_dewarp/doc/download_dataset.sh new file mode 100644 index 0000000..7afe858 --- /dev/null +++ b/doc_dewarp/doc/download_dataset.sh @@ -0,0 +1,161 @@ +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_1.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_2.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_3.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_4.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_5.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_6.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_7.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_8.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_9.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_10.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_11.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_12.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_13.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_14.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_15.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_16.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_17.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_18.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_19.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_20.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_21.zip" + +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_1.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_2.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_3.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_4.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_5.zip" +wget 
"https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_6.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_7.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_8.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_9.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_10.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_11.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_12.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_13.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_14.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_15.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_16.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_17.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_18.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_19.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_20.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_21.zip" + +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_1.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_2.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_3.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_4.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_5.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_6.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_7.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_8.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_9.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_10.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_11.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_12.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_13.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_14.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_15.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_16.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_17.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_18.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_19.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_20.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_21.zip" + +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_1.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_2.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_3.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_4.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_5.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_6.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_7.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_8.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_10.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_11.zip" +wget 
"https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_12.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_13.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_14.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_15.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_16.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_17.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_18.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_19.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_20.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_21.zip" + +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_1.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_2.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_3.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_4.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_5.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_6.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_7.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_8.zip" + +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_1.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_2.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_3.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_4.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_5.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_6.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_7.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_8.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_9.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_10.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_11.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_12.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_13.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_14.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_15.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_16.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_17.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_18.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_19.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_20.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_21.zip" + +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_1.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_2.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_3.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_4.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_5.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_6.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_7.zip" +wget 
"https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_8.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_9.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_10.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_11.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_12.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_13.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_14.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_15.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_16.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_17.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_18.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_19.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_20.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_21.zip" + +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_1.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_2.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_3.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_4.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_5.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_6.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_7.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_8.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_9.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_10.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_11.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_12.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_13.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_14.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_15.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_16.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_17.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_18.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_19.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_20.zip" +wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_21.zip" diff --git a/doc_dewarp/doc/imgs/document_image_rectification.jpg b/doc_dewarp/doc/imgs/document_image_rectification.jpg new file mode 100644 index 0000000..e5a8b26 Binary files /dev/null and b/doc_dewarp/doc/imgs/document_image_rectification.jpg differ diff --git a/doc_dewarp/doc3d_dataset.py b/doc_dewarp/doc3d_dataset.py new file mode 100644 index 0000000..49f64f4 --- /dev/null +++ b/doc_dewarp/doc3d_dataset.py @@ -0,0 +1,129 @@ +import collections +import os +import random + +import albumentations as A +import cv2 +import hdf5storage as h5 +import numpy as np +import paddle +from paddle import io + +# Set random seed +random.seed(12345678) + + +class Doc3dDataset(io.Dataset): + def __init__(self, root, split="train", is_augment=False, image_size=512): + self.root = os.path.expanduser(root) + + self.split = split + self.is_augment = is_augment + + self.files = 
collections.defaultdict(list) + + self.image_size = ( + image_size if isinstance(image_size, tuple) else (image_size, image_size) + ) + + # Augmentation + self.augmentation = A.Compose( + [ + A.ColorJitter(), + ] + ) + + for split in ["train", "val"]: + path = os.path.join(self.root, split + ".txt") + file_list = [] + with open(path, "r") as file: + file_list = [file_id.rstrip() for file_id in file.readlines()] + self.files[split] = file_list + + def __len__(self): + return len(self.files[self.split]) + + def __getitem__(self, index): + image_name = self.files[self.split][index] + + # Read image + image_path = os.path.join(self.root, "img", image_name + ".png") + image = cv2.imread(image_path) + + # Read 3D Coordinates + wc_path = os.path.join(self.root, "wc", image_name + ".exr") + wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) + + # Read backward map + bm_path = os.path.join(self.root, "bm", image_name + ".mat") + bm = h5.loadmat(bm_path)["bm"] + + image, bm = self.transform(wc, bm, image) + + return image, bm + + def tight_crop(self, wc: np.ndarray): + mask = ((wc[:, :, 0] != 0) & (wc[:, :, 1] != 0) & (wc[:, :, 2] != 0)).astype( + np.uint8 + ) + mask_size = mask.shape + [y, x] = mask.nonzero() + min_x = min(x) + max_x = max(x) + min_y = min(y) + max_y = max(y) + + wc = wc[min_y : max_y + 1, min_x : max_x + 1, :] + s = 10 + wc = np.pad(wc, ((s, s), (s, s), (0, 0)), "constant") + + cx1 = random.randint(0, 2 * s) + cx2 = random.randint(0, 2 * s) + 1 + cy1 = random.randint(0, 2 * s) + cy2 = random.randint(0, 2 * s) + 1 + + wc = wc[cy1:-cy2, cx1:-cx2, :] + + top: int = min_y - s + cy1 + bottom: int = mask_size[0] - max_y - s + cy2 + left: int = min_x - s + cx1 + right: int = mask_size[1] - max_x - s + cx2 + + top = np.clip(top, 0, mask_size[0]) + bottom = np.clip(bottom, 1, mask_size[0] - 1) + left = np.clip(left, 0, mask_size[1]) + right = np.clip(right, 1, mask_size[1] - 1) + + return wc, top, bottom, left, right + + def transform(self, wc, bm, img): + wc, top, bottom, left, right = self.tight_crop(wc) + + img = img[top:-bottom, left:-right, :] + + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + if self.is_augment: + img = self.augmentation(image=img)["image"] + + # resize image + img = cv2.resize(img, self.image_size) + img = img.astype(np.float32) / 255.0 + img = img.transpose(2, 0, 1) + + # resize bm + bm = bm.astype(np.float32) + bm[:, :, 1] = bm[:, :, 1] - top + bm[:, :, 0] = bm[:, :, 0] - left + bm = bm / np.array([448.0 - left - right, 448.0 - top - bottom]) + bm0 = cv2.resize(bm[:, :, 0], (self.image_size[0], self.image_size[1])) + bm1 = cv2.resize(bm[:, :, 1], (self.image_size[0], self.image_size[1])) + bm0 = bm0 * self.image_size[0] + bm1 = bm1 * self.image_size[1] + + bm = np.stack([bm0, bm1], axis=-1) + + img = paddle.to_tensor(img).astype(dtype="float32") + bm = paddle.to_tensor(bm).astype(dtype="float32") + + return img, bm diff --git a/doc_dewarp/export.py b/doc_dewarp/export.py new file mode 100644 index 0000000..5b7e5d8 --- /dev/null +++ b/doc_dewarp/export.py @@ -0,0 +1,66 @@ +import argparse +import os + +import paddle + +from GeoTr import GeoTr + + +def export(args): + model_path = args.model + imgsz = args.imgsz + format = args.format + + model = GeoTr() + checkpoint = paddle.load(model_path) + model.set_state_dict(checkpoint["model"]) + model.eval() + + dirname = os.path.dirname(model_path) + if format == "static" or format == "onnx": + model = paddle.jit.to_static( + model, + input_spec=[ + paddle.static.InputSpec(shape=[1, 3, imgsz, imgsz], 
dtype="float32") + ], + full_graph=True, + ) + paddle.jit.save(model, os.path.join(dirname, "model")) + + if format == "onnx": + onnx_path = os.path.join(dirname, "model.onnx") + os.system( + f"paddle2onnx --model_dir {dirname}" + " --model_filename model.pdmodel" + " --params_filename model.pdiparams" + f" --save_file {onnx_path}" + " --opset_version 11" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="export model") + + parser.add_argument( + "--model", + "-m", + nargs="?", + type=str, + default="", + help="The path of model", + ) + + parser.add_argument( + "--imgsz", type=int, default=288, help="The size of input image" + ) + + parser.add_argument( + "--format", + type=str, + default="static", + help="The format of exported model, which can be static or onnx", + ) + + args = parser.parse_args() + + export(args) diff --git a/doc_dewarp/extractor.py b/doc_dewarp/extractor.py new file mode 100644 index 0000000..5f3bace --- /dev/null +++ b/doc_dewarp/extractor.py @@ -0,0 +1,110 @@ +import paddle.nn as nn + +from .weight_init import weight_init_ + + +class ResidualBlock(nn.Layer): + """Residual Block with custom normalization.""" + + def __init__(self, in_planes, planes, norm_fn="group", stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2D(in_planes, planes, 3, padding=1, stride=stride) + self.conv2 = nn.Conv2D(planes, planes, 3, padding=1) + self.relu = nn.ReLU() + + if norm_fn == "group": + num_groups = planes // 8 + self.norm1 = nn.GroupNorm(num_groups, planes) + self.norm2 = nn.GroupNorm(num_groups, planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups, planes) + elif norm_fn == "batch": + self.norm1 = nn.BatchNorm2D(planes) + self.norm2 = nn.BatchNorm2D(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2D(planes) + elif norm_fn == "instance": + self.norm1 = nn.InstanceNorm2D(planes) + self.norm2 = nn.InstanceNorm2D(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2D(planes) + elif norm_fn == "none": + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + else: + self.downsample = nn.Sequential( + nn.Conv2D(in_planes, planes, 1, stride=stride), self.norm3 + ) + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x + y) + + +class BasicEncoder(nn.Layer): + """Basic Encoder with custom normalization.""" + + def __init__(self, output_dim=128, norm_fn="batch"): + super(BasicEncoder, self).__init__() + + self.norm_fn = norm_fn + + if self.norm_fn == "group": + self.norm1 = nn.GroupNorm(8, 64) + elif self.norm_fn == "batch": + self.norm1 = nn.BatchNorm2D(64) + elif self.norm_fn == "instance": + self.norm1 = nn.InstanceNorm2D(64) + elif self.norm_fn == "none": + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2D(3, 64, 7, stride=2, padding=3) + self.relu1 = nn.ReLU() + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(128, stride=2) + self.layer3 = self._make_layer(192, stride=2) + + self.conv2 = nn.Conv2D(192, output_dim, 1) + + for m in self.sublayers(): + if isinstance(m, nn.Conv2D): + weight_init_( + m.weight, "kaiming_normal_", mode="fan_out", nonlinearity="relu" + ) + elif isinstance(m, (nn.BatchNorm2D, nn.InstanceNorm2D, nn.GroupNorm)): + weight_init_(m, "Constant", value=1, bias_value=0.0) 
+    def _make_layer(self, dim, stride=1):
+        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+        layers = (layer1, layer2)
+
+        self.in_planes = dim
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.norm1(x)
+        x = self.relu1(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+
+        x = self.conv2(x)
+
+        return x
diff --git a/doc_dewarp/plots.py b/doc_dewarp/plots.py
new file mode 100644
index 0000000..fa384e1
--- /dev/null
+++ b/doc_dewarp/plots.py
@@ -0,0 +1,48 @@
+import copy
+import os
+
+import matplotlib.pyplot as plt
+import paddle.optimizer as optim
+
+from GeoTr import GeoTr
+
+
+def plot_lr_scheduler(optimizer, scheduler, epochs=65, save_dir=""):
+    """
+    Plot the learning rate schedule produced by a scheduler.
+    """
+
+    optimizer = copy.copy(optimizer)
+    scheduler = copy.copy(scheduler)
+
+    lr = []
+    for _ in range(epochs):
+        for _ in range(30):
+            lr.append(scheduler.get_lr())
+            optimizer.step()
+            scheduler.step()
+
+    epoch = [float(i) / 30.0 for i in range(len(lr))]
+
+    plt.figure()
+    plt.plot(epoch, lr, ".-", label="Learning Rate")
+    plt.xlabel("epoch")
+    plt.ylabel("Learning Rate")
+    plt.title("Learning Rate Scheduler")
+    plt.savefig(os.path.join(save_dir, "lr_scheduler.png"), dpi=300)
+    plt.close()
+
+
+if __name__ == "__main__":
+    model = GeoTr()
+
+    scheduler = optim.lr.OneCycleLR(
+        max_learning_rate=1e-4,
+        total_steps=1950,
+        phase_pct=0.1,
+        end_learning_rate=1e-4 / 2.5e5,
+    )
+    optimizer = optim.AdamW(learning_rate=scheduler, parameters=model.parameters())
+    plot_lr_scheduler(
+        scheduler=scheduler, optimizer=optimizer, epochs=65, save_dir="./"
+    )
diff --git a/doc_dewarp/position_encoding.py b/doc_dewarp/position_encoding.py
new file mode 100644
index 0000000..327f7ea
--- /dev/null
+++ b/doc_dewarp/position_encoding.py
@@ -0,0 +1,124 @@
+import math
+from typing import Optional
+
+import paddle
+import paddle.nn as nn
+import paddle.nn.initializer as init
+
+
+class NestedTensor(object):
+    def __init__(self, tensors, mask: Optional[paddle.Tensor]):
+        self.tensors = tensors
+        self.mask = mask
+
+    def decompose(self):
+        return self.tensors, self.mask
+
+    def __repr__(self):
+        return str(self.tensors)
+
+
+class PositionEmbeddingSine(nn.Layer):
+    """
+    This is a more standard version of the position embedding, very similar to the one
+    used by the Attention Is All You Need paper, generalized to work on images.
+ """ + + def __init__( + self, num_pos_feats=64, temperature=10000, normalize=False, scale=None + ): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + + if scale is None: + scale = 2 * math.pi + + self.scale = scale + + def forward(self, mask): + assert mask is not None + + y_embed = mask.cumsum(axis=1, dtype="float32") + x_embed = mask.cumsum(axis=2, dtype="float32") + + if self.normalize: + eps = 1e-06 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = paddle.arange(end=self.num_pos_feats, dtype="float32") + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + + pos_x = paddle.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), axis=4 + ).flatten(3) + pos_y = paddle.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), axis=4 + ).flatten(3) + + pos = paddle.concat((pos_y, pos_x), axis=3).transpose(perm=[0, 3, 1, 2]) + + return pos + + +class PositionEmbeddingLearned(nn.Layer): + """ + Absolute pos embedding, learned. + """ + + def __init__(self, num_pos_feats=256): + super().__init__() + self.row_embed = nn.Embedding(50, num_pos_feats) + self.col_embed = nn.Embedding(50, num_pos_feats) + self.reset_parameters() + + def reset_parameters(self): + init_Constant = init.Uniform() + init_Constant(self.row_embed.weight) + init_Constant(self.col_embed.weight) + + def forward(self, tensor_list: NestedTensor): + x = tensor_list.tensors + + h, w = x.shape[-2:] + + i = paddle.arange(end=w) + j = paddle.arange(end=h) + + x_emb = self.col_embed(i) + y_emb = self.row_embed(j) + + pos = ( + paddle.concat( + [ + x_emb.unsqueeze(0).tile([h, 1, 1]), + y_emb.unsqueeze(1).tile([1, w, 1]), + ], + axis=-1, + ) + .transpose([2, 0, 1]) + .unsqueeze(0) + .tile([x.shape[0], 1, 1, 1]) + ) + + return pos + + +def build_position_encoding(hidden_dim=512, position_embedding="sine"): + N_steps = hidden_dim // 2 + + if position_embedding in ("v2", "sine"): + position_embedding = PositionEmbeddingSine(N_steps, normalize=True) + elif position_embedding in ("v3", "learned"): + position_embedding = PositionEmbeddingLearned(N_steps) + else: + raise ValueError(f"not supported {position_embedding}") + return position_embedding diff --git a/doc_dewarp/predict.py b/doc_dewarp/predict.py new file mode 100644 index 0000000..bed8c63 --- /dev/null +++ b/doc_dewarp/predict.py @@ -0,0 +1,69 @@ +import argparse + +import cv2 +import paddle + +from GeoTr import GeoTr +from utils import to_image, to_tensor + + +def run(args): + image_path = args.image + model_path = args.model + output_path = args.output + + checkpoint = paddle.load(model_path) + state_dict = checkpoint["model"] + model = GeoTr() + model.set_state_dict(state_dict) + model.eval() + + img_org = cv2.imread(image_path) + img = cv2.resize(img_org, (288, 288)) + x = to_tensor(img) + y = to_tensor(img_org) + bm = model(x) + bm = paddle.nn.functional.interpolate( + bm, y.shape[2:], mode="bilinear", align_corners=False + ) + bm_nhwc = bm.transpose([0, 2, 3, 1]) + out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2) + out_image = to_image(out) + cv2.imwrite(output_path, out_image) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="predict") + 
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="predict")
+
+    parser.add_argument(
+        "--image",
+        "-i",
+        nargs="?",
+        type=str,
+        default="",
+        help="The path of image",
+    )
+
+    parser.add_argument(
+        "--model",
+        "-m",
+        nargs="?",
+        type=str,
+        default="",
+        help="The path of model",
+    )
+
+    parser.add_argument(
+        "--output",
+        "-o",
+        nargs="?",
+        type=str,
+        default="",
+        help="The path of output",
+    )
+
+    args = parser.parse_args()
+
+    print(args)
+
+    run(args)
diff --git a/doc_dewarp/requirements.txt b/doc_dewarp/requirements.txt
new file mode 100644
index 0000000..102d164
--- /dev/null
+++ b/doc_dewarp/requirements.txt
@@ -0,0 +1,7 @@
+hdf5storage
+loguru
+numpy
+scipy
+opencv-python
+matplotlib
+albumentations
diff --git a/doc_dewarp/split_dataset.py b/doc_dewarp/split_dataset.py
new file mode 100644
index 0000000..6f83b3e
--- /dev/null
+++ b/doc_dewarp/split_dataset.py
@@ -0,0 +1,57 @@
+import argparse
+import glob
+import os
+import random
+from pathlib import Path
+
+from loguru import logger
+
+random.seed(1234567)
+
+
+def run(args):
+    data_root = os.path.expanduser(args.data_root)
+    ratio = args.train_ratio
+
+    data_path = os.path.join(data_root, "img", "*", "*.png")
+    img_list = glob.glob(data_path, recursive=True)
+    # Sort in place (glob order is filesystem-dependent) so the fixed-seed
+    # shuffle below produces a reproducible split; the previous bare
+    # sorted(img_list) call discarded its result.
+    img_list.sort()
+    random.shuffle(img_list)
+
+    train_size = int(len(img_list) * ratio)
+
+    train_text_path = os.path.join(data_root, "train.txt")
+    with open(train_text_path, "w") as file:
+        for item in img_list[:train_size]:
+            parts = Path(item).parts
+            item = os.path.join(parts[-2], parts[-1])
+            file.write("%s\n" % item.split(".png")[0])
+
+    val_text_path = os.path.join(data_root, "val.txt")
+    with open(val_text_path, "w") as file:
+        for item in img_list[train_size:]:
+            parts = Path(item).parts
+            item = os.path.join(parts[-2], parts[-1])
+            file.write("%s\n" % item.split(".png")[0])
+
+    logger.info(f"TRAIN LABEL: {train_text_path}")
+    logger.info(f"VAL LABEL: {val_text_path}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--data_root",
+        type=str,
+        default="~/datasets/doc3d",
+        help="Data path to load data",
+    )
+    parser.add_argument(
+        "--train_ratio", type=float, default=0.8, help="Ratio of training data"
+    )
+
+    args = parser.parse_args()
+
+    logger.info(args)
+
+    run(args)
diff --git a/doc_dewarp/train.py b/doc_dewarp/train.py
new file mode 100644
index 0000000..264b4b9
--- /dev/null
+++ b/doc_dewarp/train.py
@@ -0,0 +1,381 @@
+import argparse
+import inspect
+import os
+import random
+from pathlib import Path
+from typing import Optional
+
+import numpy as np
+import paddle
+import paddle.nn as nn
+import paddle.optimizer as optim
+from loguru import logger
+from paddle.io import DataLoader
+from paddle.nn import functional as F
+from paddle_msssim import ms_ssim, ssim
+
+from doc3d_dataset import Doc3dDataset
+from GeoTr import GeoTr
+from utils import to_image
+
+FILE = Path(__file__).resolve()
+ROOT = FILE.parents[0]
+RANK = int(os.getenv("RANK", -1))
+
+
+def init_seeds(seed=0, deterministic=False):
+    random.seed(seed)
+    np.random.seed(seed)
+    paddle.seed(seed)
+    if deterministic:
+        os.environ["FLAGS_cudnn_deterministic"] = "1"
+        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+        os.environ["PYTHONHASHSEED"] = str(seed)
+
+
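+# Doc3dDataset reads the train/val index files written by split_dataset.py:
+# one relative path per line, without the ".png" suffix, presumably resolved
+# against the img/ folder by Doc3dDataset (not shown in this diff). An entry
+# such as "1/page_0001" is illustrative only.
+
+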
+def colorstr(*input):
+    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code,
+    # i.e. colorstr('blue', 'hello world')
+
+    *args, string = (
+        input if len(input) > 1 else ("blue", "bold", input[0])
+    )  # color arguments, string
+
+    colors = {
+        "black": "\033[30m",  # basic colors
+        "red": "\033[31m",
+        "green": "\033[32m",
+        "yellow": "\033[33m",
+        "blue": "\033[34m",
+        "magenta": "\033[35m",
+        "cyan": "\033[36m",
+        "white": "\033[37m",
+        "bright_black": "\033[90m",  # bright colors
+        "bright_red": "\033[91m",
+        "bright_green": "\033[92m",
+        "bright_yellow": "\033[93m",
+        "bright_blue": "\033[94m",
+        "bright_magenta": "\033[95m",
+        "bright_cyan": "\033[96m",
+        "bright_white": "\033[97m",
+        "end": "\033[0m",  # misc
+        "bold": "\033[1m",
+        "underline": "\033[4m",
+    }
+
+    return "".join(colors[x] for x in args) + f"{string}" + colors["end"]
+
+
+def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
+    # Print function arguments (optional args dict)
+    x = inspect.currentframe().f_back  # previous frame
+    file, _, func, _, _ = inspect.getframeinfo(x)
+    if args is None:  # get args automatically
+        args, _, _, frm = inspect.getargvalues(x)
+        args = {k: v for k, v in frm.items() if k in args}
+    try:
+        file = Path(file).resolve().relative_to(ROOT).with_suffix("")
+    except ValueError:
+        file = Path(file).stem
+    s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
+    logger.info(colorstr(s) + ", ".join(f"{k}={v}" for k, v in args.items()))
+
+
+def increment_path(path, exist_ok=False, sep="", mkdir=False):
+    # Increment file or directory path,
+    # i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
+    path = Path(path)  # os-agnostic
+    if path.exists() and not exist_ok:
+        path, suffix = (
+            (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
+        )
+
+        for n in range(2, 9999):
+            p = f"{path}{sep}{n}{suffix}"  # increment path
+            if not os.path.exists(p):
+                break
+        path = Path(p)
+
+    if mkdir:
+        path.mkdir(parents=True, exist_ok=True)  # make directory
+
+    return path
+
+
+def train(args):
+    save_dir = Path(args.save_dir)
+
+    use_vdl = args.use_vdl
+    if use_vdl:
+        from visualdl import LogWriter
+
+        log_dir = save_dir / "vdl"
+        vdl_writer = LogWriter(str(log_dir))
+
+    # Directories: create the weights dir itself (checkpoints are written into
+    # it below); the previous weights_dir.parent.mkdir only created save_dir.
+    weights_dir = save_dir / "weights"
+    weights_dir.mkdir(parents=True, exist_ok=True)
+
+    last = weights_dir / "last.ckpt"
+    best = weights_dir / "best.ckpt"
+
+    # Hyperparameters
+
+    # Config
+    init_seeds(args.seed)
+
+    # Train loader
+    train_dataset = Doc3dDataset(
+        args.data_root,
+        split="train",
+        is_augment=True,
+        image_size=args.img_size,
+    )
+
+    train_loader = DataLoader(
+        train_dataset,
+        batch_size=args.batch_size,
+        shuffle=True,
+        num_workers=args.workers,
+    )
+
+    # Validation loader
+    val_dataset = Doc3dDataset(
+        args.data_root,
+        split="val",
+        is_augment=False,
+        image_size=args.img_size,
+    )
+
+    val_loader = DataLoader(
+        val_dataset, batch_size=args.batch_size, num_workers=args.workers
+    )
+
+    # Model
+    model = GeoTr()
+
+    if use_vdl:
+        vdl_writer.add_graph(
+            model,
+            input_spec=[
+                paddle.static.InputSpec([1, 3, args.img_size, args.img_size], "float32")
+            ],
+        )
+
+    # Data Parallel Mode
+    if RANK == -1 and paddle.device.cuda.device_count() > 1:
+        model = paddle.DataParallel(model)
+
+    # Scheduler
+    scheduler = optim.lr.OneCycleLR(
+        max_learning_rate=args.lr,
+        total_steps=args.epochs * len(train_loader),
+        phase_pct=0.1,
+        end_learning_rate=args.lr / 2.5e5,
+    )
+
+    # Optimizer
+    optimizer = optim.AdamW(
+        learning_rate=scheduler,
+        parameters=model.parameters(),
+    )
+
+    # Loss functions
+    l1_loss_fn = nn.L1Loss()
+    mse_loss_fn = nn.MSELoss()
+
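+    # The supervision target is the dense backward map itself: pred has two
+    # channels holding x / y sampling coordinates in pixel units (they are
+    # normalized by img_size before grid_sample in validation), so L1/MSE are
+    # computed directly in pixel units after transposing pred to NHWC.
+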
+    # Resume
+    best_fitness, start_epoch = 0.0, 0
+    if args.resume:
+        ckpt = paddle.load(args.resume)
+        model.set_state_dict(ckpt["model"])
+        optimizer.set_state_dict(ckpt["optimizer"])
+        scheduler.set_state_dict(ckpt["scheduler"])
+        best_fitness = ckpt["best_fitness"]
+        start_epoch = ckpt["epoch"] + 1
+
+    # Train
+    for epoch in range(start_epoch, args.epochs):
+        model.train()
+
+        for i, (img, target) in enumerate(train_loader):
+            img = paddle.to_tensor(img)  # NCHW
+            target = paddle.to_tensor(target)  # NHWC
+
+            pred = model(img)  # NCHW
+            pred_nhwc = pred.transpose([0, 2, 3, 1])
+
+            loss = l1_loss_fn(pred_nhwc, target)
+            mse_loss = mse_loss_fn(pred_nhwc, target)
+
+            if use_vdl:
+                vdl_writer.add_scalar(
+                    "Train/L1 Loss", float(loss), epoch * len(train_loader) + i
+                )
+                vdl_writer.add_scalar(
+                    "Train/MSE Loss", float(mse_loss), epoch * len(train_loader) + i
+                )
+                vdl_writer.add_scalar(
+                    "Train/Learning Rate",
+                    float(scheduler.get_lr()),
+                    epoch * len(train_loader) + i,
+                )
+
+            loss.backward()
+            optimizer.step()
+            scheduler.step()
+            optimizer.clear_grad()
+
+            if i % 10 == 0:
+                logger.info(
+                    f"[TRAIN MODE] Epoch: {epoch}, Iter: {i}, L1 Loss: {float(loss)}, "
+                    f"MSE Loss: {float(mse_loss)}, LR: {float(scheduler.get_lr())}"
+                )
+
+        # Validation
+        model.eval()
+
+        with paddle.no_grad():
+            avg_ssim = paddle.zeros([])
+            avg_ms_ssim = paddle.zeros([])
+            avg_l1_loss = paddle.zeros([])
+            avg_mse_loss = paddle.zeros([])
+
+            for i, (img, target) in enumerate(val_loader):
+                img = paddle.to_tensor(img)
+                target = paddle.to_tensor(target)
+
+                pred = model(img)
+                pred_nhwc = pred.transpose([0, 2, 3, 1])
+
+                # predict image
+                out = F.grid_sample(img, (pred_nhwc / args.img_size - 0.5) * 2)
+                out_gt = F.grid_sample(img, (target / args.img_size - 0.5) * 2)
+
+                # calculate ssim
+                ssim_val = ssim(out, out_gt, data_range=1.0)
+                ms_ssim_val = ms_ssim(out, out_gt, data_range=1.0)
+
+                loss = l1_loss_fn(pred_nhwc, target)
+                mse_loss = mse_loss_fn(pred_nhwc, target)
+
+                # accumulate metrics
+                avg_ssim += ssim_val
+                avg_ms_ssim += ms_ssim_val
+                avg_l1_loss += loss
+                avg_mse_loss += mse_loss
+
+                if i % 10 == 0:
+                    logger.info(
+                        f"[VAL MODE] Epoch: {epoch}, VAL Iter: {i}, "
+                        f"L1 Loss: {float(loss)}, MSE Loss: {float(mse_loss)}, "
+                        f"MS-SSIM: {float(ms_ssim_val)}, SSIM: {float(ssim_val)}"
+                    )
+
+                if use_vdl and i == 0:
+                    img_0 = to_image(out[0])
+                    img_gt_0 = to_image(out_gt[0])
+                    vdl_writer.add_image("Val/Predicted Image No.0", img_0, epoch)
+                    vdl_writer.add_image("Val/Target Image No.0", img_gt_0, epoch)
+
+                    img_1 = to_image(out[1])
+                    img_gt_1 = to_image(out_gt[1])
+                    vdl_writer.add_image("Val/Predicted Image No.1", img_1, epoch)
+                    vdl_writer.add_image("Val/Target Image No.1", img_gt_1, epoch)
+
+                    img_2 = to_image(out[2])
+                    img_gt_2 = to_image(out_gt[2])
+                    vdl_writer.add_image("Val/Predicted Image No.2", img_2, epoch)
+                    vdl_writer.add_image("Val/Target Image No.2", img_gt_2, epoch)
+
+            avg_ssim /= len(val_loader)
+            avg_ms_ssim /= len(val_loader)
+            avg_l1_loss /= len(val_loader)
+            avg_mse_loss /= len(val_loader)
+
+            # Log the epoch averages, not the metrics of the last batch.
+            if use_vdl:
+                vdl_writer.add_scalar("Val/L1 Loss", float(avg_l1_loss), epoch)
+                vdl_writer.add_scalar("Val/MSE Loss", float(avg_mse_loss), epoch)
+                vdl_writer.add_scalar("Val/SSIM", float(avg_ssim), epoch)
+                vdl_writer.add_scalar("Val/MS-SSIM", float(avg_ms_ssim), epoch)
+
+        # Save
+        ckpt = {
+            "model": model.state_dict(),
+            "optimizer": optimizer.state_dict(),
+            "scheduler": scheduler.state_dict(),
+            "best_fitness": best_fitness,
+            "epoch": epoch,
+        }
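+
+        # "last.ckpt" is overwritten every epoch, while "best.ckpt" is only
+        # refreshed when the epoch's average SSIM improves on best_fitness;
+        # --resume restores model / optimizer / scheduler state and the epoch.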
+
+        paddle.save(ckpt, str(last))
+
+        if best_fitness < avg_ssim:
+            best_fitness = avg_ssim
+            paddle.save(ckpt, str(best))
+
+    if use_vdl:
+        vdl_writer.close()
+
+
+def main(args):
+    print_args(vars(args))
+
+    args.save_dir = str(
+        increment_path(Path(args.project) / args.name, exist_ok=args.exist_ok)
+    )
+
+    train(args)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Hyperparams")
+    parser.add_argument(
+        "--data-root",
+        nargs="?",
+        type=str,
+        default="~/datasets/doc3d",
+        help="The root path of the dataset",
+    )
+    parser.add_argument(
+        "--img-size",
+        nargs="?",
+        type=int,
+        default=288,
+        help="The size of the input image",
+    )
+    parser.add_argument(
+        "--epochs",
+        nargs="?",
+        type=int,
+        default=65,
+        help="The number of training epochs",
+    )
+    parser.add_argument(
+        "--batch-size", nargs="?", type=int, default=12, help="Batch Size"
+    )
+    parser.add_argument(
+        "--lr", nargs="?", type=float, default=1e-04, help="Learning Rate"
+    )
+    parser.add_argument(
+        "--resume",
+        nargs="?",
+        type=str,
+        default=None,
+        help="Path to previous saved model to restart from",
+    )
+    parser.add_argument("--workers", type=int, default=8, help="max dataloader workers")
+    parser.add_argument(
+        "--project", default=ROOT / "runs/train", help="save to project/name"
+    )
+    parser.add_argument("--name", default="exp", help="save to project/name")
+    parser.add_argument(
+        "--exist-ok",
+        action="store_true",
+        help="existing project/name ok, do not increment",
+    )
+    parser.add_argument("--seed", type=int, default=0, help="Global training seed")
+    parser.add_argument("--use-vdl", action="store_true", help="use VisualDL as logger")
+    args = parser.parse_args()
+    main(args)
diff --git a/doc_dewarp/train.sh b/doc_dewarp/train.sh
new file mode 100644
index 0000000..2b7da99
--- /dev/null
+++ b/doc_dewarp/train.sh
@@ -0,0 +1,10 @@
+export OPENCV_IO_ENABLE_OPENEXR=1
+export FLAGS_logtostderr=0
+export CUDA_VISIBLE_DEVICES=0
+
+python train.py --img-size 288 \
+    --name "DocTr++" \
+    --batch-size 12 \
+    --lr 1e-4 \
+    --exist-ok \
+    --use-vdl
diff --git a/doc_dewarp/utils.py b/doc_dewarp/utils.py
new file mode 100644
index 0000000..e887908
--- /dev/null
+++ b/doc_dewarp/utils.py
@@ -0,0 +1,67 @@
+import numpy as np
+import paddle
+from paddle.nn import functional as F
+
+
+def to_tensor(img: np.ndarray):
+    """
+    Converts a numpy array image (HWC, BGR) to a Paddle tensor (NCHW, RGB).
+
+    Args:
+        img (numpy.ndarray): The input image as a numpy array.
+
+    Returns:
+        out (paddle.Tensor): The output tensor.
+    """
+    img = img[:, :, ::-1]
+    img = img.astype("float32") / 255.0
+    img = img.transpose(2, 0, 1)
+    out: paddle.Tensor = paddle.to_tensor(img)
+    out = paddle.unsqueeze(out, axis=0)
+
+    return out
+
+
+def to_image(x: paddle.Tensor):
+    """
+    Converts a Paddle tensor (NCHW, RGB) to a numpy array image (HWC, BGR).
+
+    Args:
+        x (paddle.Tensor): The input tensor.
+
+    Returns:
+        out (numpy.ndarray): The output image as a numpy array.
+    """
+    out: np.ndarray = x.squeeze().numpy()
+    out = out.transpose(1, 2, 0)
+    out = out * 255.0
+    out = out.astype("uint8")
+    out = out[:, :, ::-1]
+
+    return out
+
+
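+# Round-trip sketch: to_tensor followed by to_image approximately recovers the
+# input (up to float rounding and uint8 truncation), e.g.
+#   x = to_tensor(img)    # [1, 3, H, W], float32 in [0, 1], RGB
+#   img2 = to_image(x)    # HWC uint8, BGR, close to img
+
+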
+ """ + _, _, h, w = img.shape + + if bm_data_format == "NHWC": + bm = bm.transpose([0, 3, 1, 2]) + + # NCHW + bm = F.upsample(bm, size=(h, w), mode="bilinear", align_corners=True) + # NHWC + bm = bm.transpose([0, 2, 3, 1]) + # NCHW + out = F.grid_sample(img, bm) + + return out diff --git a/doc_dewarp/weight_init.py b/doc_dewarp/weight_init.py new file mode 100644 index 0000000..c83178c --- /dev/null +++ b/doc_dewarp/weight_init.py @@ -0,0 +1,152 @@ +import math + +import numpy as np +import paddle +import paddle.nn.initializer as init +from scipy import special + + +def weight_init_( + layer, func, weight_name=None, bias_name=None, bias_value=0.0, **kwargs +): + """ + In-place params init function. + Usage: + .. code-block:: python + + import paddle + import numpy as np + + data = np.ones([3, 4], dtype='float32') + linear = paddle.nn.Linear(4, 4) + input = paddle.to_tensor(data) + print(linear.weight) + linear(input) + + weight_init_(linear, 'Normal', 'fc_w0', 'fc_b0', std=0.01, mean=0.1) + print(linear.weight) + """ + + if hasattr(layer, "weight") and layer.weight is not None: + getattr(init, func)(**kwargs)(layer.weight) + if weight_name is not None: + # override weight name + layer.weight.name = weight_name + + if hasattr(layer, "bias") and layer.bias is not None: + init.Constant(bias_value)(layer.bias) + if bias_name is not None: + # override bias name + layer.bias.name = bias_name + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + print( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect." + ) + + with paddle.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + lower = norm_cdf((a - mean) / std) + upper = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to [2l-1, 2u-1]. 
+        tmp = np.random.uniform(
+            2 * lower - 1, 2 * upper - 1, size=list(tensor.shape)
+        ).astype(np.float32)
+
+        # Use inverse cdf transform for normal distribution to get truncated
+        # standard normal
+        tmp = special.erfinv(tmp)
+
+        # Transform to proper mean, std
+        tmp *= std * math.sqrt(2.0)
+        tmp += mean
+
+        # Clamp to ensure it's in the proper range
+        tmp = np.clip(tmp, a, b)
+        tensor.set_value(paddle.to_tensor(tmp))
+
+    return tensor
+
+
+def _calculate_fan_in_and_fan_out(tensor):
+    dimensions = tensor.dim()
+    if dimensions < 2:
+        raise ValueError(
+            "Fan in and fan out can not be computed for tensor "
+            "with fewer than 2 dimensions"
+        )
+
+    num_input_fmaps = tensor.shape[1]
+    num_output_fmaps = tensor.shape[0]
+    receptive_field_size = 1
+    if tensor.dim() > 2:
+        receptive_field_size = tensor[0][0].numel()
+    fan_in = num_input_fmaps * receptive_field_size
+    fan_out = num_output_fmaps * receptive_field_size
+
+    return fan_in, fan_out
+
+
+def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
+    return _no_grad_trunc_normal_(tensor, mean, std, a, b)
+
+
+def kaiming_normal_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
+    def _calculate_correct_fan(tensor, mode):
+        mode = mode.lower()
+        valid_modes = ["fan_in", "fan_out"]
+        if mode not in valid_modes:
+            raise ValueError(
+                "Mode {} not supported, please use one of {}".format(mode, valid_modes)
+            )
+
+        fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
+        return fan_in if mode == "fan_in" else fan_out
+
+    def calculate_gain(nonlinearity, param=None):
+        linear_fns = [
+            "linear",
+            "conv1d",
+            "conv2d",
+            "conv3d",
+            "conv_transpose1d",
+            "conv_transpose2d",
+            "conv_transpose3d",
+        ]
+        if nonlinearity in linear_fns or nonlinearity == "sigmoid":
+            return 1
+        elif nonlinearity == "tanh":
+            return 5.0 / 3
+        elif nonlinearity == "relu":
+            return math.sqrt(2.0)
+        elif nonlinearity == "leaky_relu":
+            if param is None:
+                negative_slope = 0.01
+            elif (
+                not isinstance(param, bool) and isinstance(param, int)
+            ) or isinstance(param, float):
+                negative_slope = param
+            else:
+                raise ValueError("negative_slope {} not a valid number".format(param))
+            return math.sqrt(2.0 / (1 + negative_slope**2))
+        else:
+            raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
+
+    fan = _calculate_correct_fan(tensor, mode)
+    gain = calculate_gain(nonlinearity, a)
+    std = gain / math.sqrt(fan)
+    with paddle.no_grad():
+        paddle.nn.initializer.Normal(0, std)(tensor)
+    return tensor
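+
+
+# Minimal usage sketch for the helpers above (the `conv` and `fc` layers are
+# illustrative only, not part of this repository):
+#
+#   import paddle.nn as nn
+#   conv = nn.Conv2D(3, 64, 3)
+#   kaiming_normal_(conv.weight, mode="fan_out", nonlinearity="relu")
+#   fc = nn.Linear(64, 10)
+#   trunc_normal_(fc.weight, std=0.02)
+#   weight_init_(fc, "Constant", value=0.0, bias_value=0.0)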