DocTr dewarping
@@ -1,96 +0,0 @@
import os

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

from dewarp.models import get_model
from dewarp.utils import convert_state_dict

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def unwarp(img, bm):
    w, h = img.shape[0], img.shape[1]
    bm = bm.transpose(1, 2).transpose(2, 3).detach().cpu().numpy()[0, :, :, :]
    bm0 = cv2.blur(bm[:, :, 0], (3, 3))
    bm1 = cv2.blur(bm[:, :, 1], (3, 3))
    bm0 = cv2.resize(bm0, (h, w))
    bm1 = cv2.resize(bm1, (h, w))
    bm = np.stack([bm0, bm1], axis=-1)
    bm = np.expand_dims(bm, 0)
    bm = torch.from_numpy(bm).double()

    img = img.astype(float) / 255.0
    img = img.transpose((2, 0, 1))
    img = np.expand_dims(img, 0)
    img = torch.from_numpy(img).double()

    res = F.grid_sample(input=img, grid=bm, align_corners=True)
    res = res[0].numpy().transpose((1, 2, 0))

    return res


def dewarp_image(image):
    wc_model_path = "model/dewarp_model/unetnc_doc3d.pkl"
    bm_model_path = "model/dewarp_model/dnetccnl_doc3d.pkl"
    wc_model_file_name = os.path.split(wc_model_path)[1]
    wc_model_name = wc_model_file_name[:wc_model_file_name.find('_')]

    bm_model_file_name = os.path.split(bm_model_path)[1]
    bm_model_name = bm_model_file_name[:bm_model_file_name.find('_')]

    wc_n_classes = 3
    bm_n_classes = 2

    wc_img_size = (256, 256)
    bm_img_size = (128, 128)

    # Setup image
    imgorg = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    img = cv2.resize(imgorg, wc_img_size)

    img = img[:, :, ::-1]
    img = img.astype(float) / 255.0
    img = img.transpose(2, 0, 1)  # NHWC -> NCHW
    img = np.expand_dims(img, 0)
    img = torch.from_numpy(img).float()

    # Predict
    htan = nn.Hardtanh(0, 1.0)
    wc_model = get_model(wc_model_name, wc_n_classes, in_channels=3)
    if DEVICE.type == 'cpu':
        wc_state = convert_state_dict(torch.load(wc_model_path, map_location='cpu', weights_only=True)['model_state'])
    else:
        wc_state = convert_state_dict(torch.load(wc_model_path, weights_only=True)['model_state'])
    wc_model.load_state_dict(wc_state)
    wc_model.eval()
    bm_model = get_model(bm_model_name, bm_n_classes, in_channels=3)
    if DEVICE.type == 'cpu':
        bm_state = convert_state_dict(torch.load(bm_model_path, map_location='cpu', weights_only=True)['model_state'])
    else:
        bm_state = convert_state_dict(torch.load(bm_model_path, weights_only=True)['model_state'])
    bm_model.load_state_dict(bm_state)
    bm_model.eval()

    if torch.cuda.is_available():
        wc_model.cuda()
        bm_model.cuda()
        images = Variable(img.cuda())
    else:
        images = Variable(img)

    with torch.no_grad():
        wc_outputs = wc_model(images)
        pred_wc = htan(wc_outputs)
        bm_input = F.interpolate(pred_wc, bm_img_size)
        outputs_bm = bm_model(bm_input)

    # call unwarp
    uwpred = unwarp(imgorg, outputs_bm)
    uwpred = (uwpred * 255).astype(np.uint8)
    return cv2.cvtColor(uwpred, cv2.COLOR_RGB2BGR)
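For reference, a minimal sketch of how the removed PyTorch pipeline above would be invoked. The import path and the presence of the two `.pkl` checkpoints under `model/dewarp_model/` are assumptions, not shown in this commit:

```python
import cv2

# hypothetical import path; the deleted file defines dewarp_image at module level
from dewarp.dewarp import dewarp_image

bgr = cv2.imread("input/page.jpg")        # warped document photo, BGR as OpenCV loads it
flat = dewarp_image(bgr)                  # wc regression + bm regression, then grid_sample
cv2.imwrite("output/page_flat.jpg", flat)
```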
@@ -1,24 +0,0 @@
from dewarp.models.densenetccnl import dnetccnl
from dewarp.models.unetnc import UnetGenerator


def get_model(name, n_classes=1, in_channels=3):
    model = _get_model_instance(name)

    if name == 'dnetccnl':
        model = model(img_size=128, in_channels=in_channels, out_channels=n_classes, filters=32)
    elif name == 'unetnc':
        model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7)
    else:
        model = model(n_classes=n_classes)
    return model


def _get_model_instance(name):
    try:
        return {
            'dnetccnl': dnetccnl,
            'unetnc': UnetGenerator,
        }[name]
    except:
        print('Model {} not available'.format(name))
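An illustrative sketch (not part of the commit) of how the removed pipeline resolved the two networks through `get_model`, with the class counts taken from `dewarp_image` above:

```python
# 'unetnc' -> UnetGenerator predicting 3-channel world coordinates
wc_model = get_model('unetnc', n_classes=3, in_channels=3)

# 'dnetccnl' -> dnetccnl predicting the 2-channel backward map at 128x128
bm_model = get_model('dnetccnl', n_classes=2, in_channels=3)
```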
@@ -1,243 +0,0 @@
# Densenet decoder encoder with intermediate fully connected layers and dropout

import numpy as np
import torch
import torch.nn as nn


def add_coordConv_channels(t):
    n, c, h, w = t.size()
    xx_channel = np.ones((h, w))
    xx_range = np.array(range(h))
    xx_range = np.expand_dims(xx_range, -1)
    xx_coord = xx_channel * xx_range
    yy_coord = xx_coord.transpose()

    xx_coord = xx_coord / (h - 1)
    yy_coord = yy_coord / (h - 1)
    xx_coord = xx_coord * 2 - 1
    yy_coord = yy_coord * 2 - 1
    xx_coord = torch.from_numpy(xx_coord).float()
    yy_coord = torch.from_numpy(yy_coord).float()

    if t.is_cuda:
        xx_coord = xx_coord.cuda()
        yy_coord = yy_coord.cuda()

    xx_coord = xx_coord.unsqueeze(0).unsqueeze(0).repeat(n, 1, 1, 1)
    yy_coord = yy_coord.unsqueeze(0).unsqueeze(0).repeat(n, 1, 1, 1)

    t_cc = torch.cat((t, xx_coord, yy_coord), dim=1)

    return t_cc


class DenseBlockEncoder(nn.Module):
    def __init__(self, n_channels, n_convs, activation=nn.ReLU, args=[False]):
        super(DenseBlockEncoder, self).__init__()
        assert (n_convs > 0)

        self.n_channels = n_channels
        self.n_convs = n_convs
        self.layers = nn.ModuleList()
        for i in range(n_convs):
            self.layers.append(nn.Sequential(
                nn.BatchNorm2d(n_channels),
                activation(*args),
                nn.Conv2d(n_channels, n_channels, 3, stride=1, padding=1, bias=False), ))

    def forward(self, inputs):
        outputs = []

        for i, layer in enumerate(self.layers):
            if i > 0:
                next_output = 0
                for no in outputs:
                    next_output = next_output + no
                outputs.append(next_output)
            else:
                outputs.append(layer(inputs))
        return outputs[-1]


# Dense block in decoder.
class DenseBlockDecoder(nn.Module):
    def __init__(self, n_channels, n_convs, activation=nn.ReLU, args=[False]):
        super(DenseBlockDecoder, self).__init__()
        assert (n_convs > 0)

        self.n_channels = n_channels
        self.n_convs = n_convs
        self.layers = nn.ModuleList()
        for i in range(n_convs):
            self.layers.append(nn.Sequential(
                nn.BatchNorm2d(n_channels),
                activation(*args),
                nn.ConvTranspose2d(n_channels, n_channels, 3, stride=1, padding=1, bias=False), ))

    def forward(self, inputs):
        outputs = []

        for i, layer in enumerate(self.layers):
            if i > 0:
                next_output = 0
                for no in outputs:
                    next_output = next_output + no
                outputs.append(next_output)
            else:
                outputs.append(layer(inputs))
        return outputs[-1]


class DenseTransitionBlockEncoder(nn.Module):
    def __init__(self, n_channels_in, n_channels_out, mp, activation=nn.ReLU, args=[False]):
        super(DenseTransitionBlockEncoder, self).__init__()
        self.n_channels_in = n_channels_in
        self.n_channels_out = n_channels_out
        self.mp = mp
        self.main = nn.Sequential(
            nn.BatchNorm2d(n_channels_in),
            activation(*args),
            nn.Conv2d(n_channels_in, n_channels_out, 1, stride=1, padding=0, bias=False),
            nn.MaxPool2d(mp),
        )

    def forward(self, inputs):
        return self.main(inputs)


class DenseTransitionBlockDecoder(nn.Module):
    def __init__(self, n_channels_in, n_channels_out, activation=nn.ReLU, args=[False]):
        super(DenseTransitionBlockDecoder, self).__init__()
        self.n_channels_in = n_channels_in
        self.n_channels_out = n_channels_out
        self.main = nn.Sequential(
            nn.BatchNorm2d(n_channels_in),
            activation(*args),
            nn.ConvTranspose2d(n_channels_in, n_channels_out, 4, stride=2, padding=1, bias=False),
        )

    def forward(self, inputs):
        return self.main(inputs)


## Dense encoders and decoders for images of size 128 x 128
class waspDenseEncoder128(nn.Module):
    def __init__(self, nc=1, ndf=32, ndim=128, activation=nn.LeakyReLU, args=[0.2, False], f_activation=nn.Tanh,
                 f_args=[]):
        super(waspDenseEncoder128, self).__init__()
        self.ndim = ndim

        self.main = nn.Sequential(
            # input is (nc) x 128 x 128
            nn.BatchNorm2d(nc),
            nn.ReLU(True),
            nn.Conv2d(nc, ndf, 4, stride=2, padding=1),

            # state size. (ndf) x 64 x 64
            DenseBlockEncoder(ndf, 6),
            DenseTransitionBlockEncoder(ndf, ndf * 2, 2, activation=activation, args=args),

            # state size. (ndf*2) x 32 x 32
            DenseBlockEncoder(ndf * 2, 12),
            DenseTransitionBlockEncoder(ndf * 2, ndf * 4, 2, activation=activation, args=args),

            # state size. (ndf*4) x 16 x 16
            DenseBlockEncoder(ndf * 4, 16),
            DenseTransitionBlockEncoder(ndf * 4, ndf * 8, 2, activation=activation, args=args),

            # state size. (ndf*8) x 8 x 8
            DenseBlockEncoder(ndf * 8, 16),
            DenseTransitionBlockEncoder(ndf * 8, ndf * 8, 2, activation=activation, args=args),

            # state size. (ndf*8) x 4 x 4
            DenseBlockEncoder(ndf * 8, 16),
            DenseTransitionBlockEncoder(ndf * 8, ndim, 4, activation=activation, args=args),
            f_activation(*f_args),
        )

    def forward(self, input):
        input = add_coordConv_channels(input)
        output = self.main(input).view(-1, self.ndim)
        # print(output.size())
        return output


class waspDenseDecoder128(nn.Module):
    def __init__(self, nz=128, nc=1, ngf=32, lb=0, ub=1, activation=nn.ReLU, args=[False], f_activation=nn.Hardtanh,
                 f_args=[]):
        super(waspDenseDecoder128, self).__init__()
        self.main = nn.Sequential(
            # input is Z, going into a convolution
            nn.BatchNorm2d(nz),
            activation(*args),
            nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),

            # state size. (ngf*8) x 4 x 4
            DenseBlockDecoder(ngf * 8, 16),
            DenseTransitionBlockDecoder(ngf * 8, ngf * 8),

            # state size. (ngf*8) x 8 x 8
            DenseBlockDecoder(ngf * 8, 16),
            DenseTransitionBlockDecoder(ngf * 8, ngf * 4),

            # state size. (ngf*4) x 16 x 16
            DenseBlockDecoder(ngf * 4, 12),
            DenseTransitionBlockDecoder(ngf * 4, ngf * 2),

            # state size. (ngf*2) x 32 x 32
            DenseBlockDecoder(ngf * 2, 6),
            DenseTransitionBlockDecoder(ngf * 2, ngf),

            # state size. (ngf) x 64 x 64
            DenseBlockDecoder(ngf, 6),
            DenseTransitionBlockDecoder(ngf, ngf),

            # state size. (ngf) x 128 x 128
            nn.BatchNorm2d(ngf),
            activation(*args),
            nn.ConvTranspose2d(ngf, nc, 3, stride=1, padding=1, bias=False),
            f_activation(*f_args),
        )
        # self.smooth=nn.Sequential(
        #     nn.Conv2d(nc, nc, 1, stride=1, padding=0, bias=False),
        #     f_activation(*f_args),
        # )

    def forward(self, inputs):
        # return self.smooth(self.main(inputs))
        return self.main(inputs)


class dnetccnl(nn.Module):
    # in_channels -> nc | encoder first layer
    # filters -> ndf | encoder first layer
    # img_size(h,w) -> ndim
    # out_channels -> optical flow (x,y)

    def __init__(self, img_size=128, in_channels=1, out_channels=2, filters=32, fc_units=100):
        super(dnetccnl, self).__init__()
        self.nc = in_channels
        self.nf = filters
        self.ndim = img_size
        self.oc = out_channels
        self.fcu = fc_units

        self.encoder = waspDenseEncoder128(nc=self.nc + 2, ndf=self.nf, ndim=self.ndim)
        self.decoder = waspDenseDecoder128(nz=self.ndim, nc=self.oc, ngf=self.nf)
        # self.fc_layers= nn.Sequential(nn.Linear(self.ndim, self.fcu),
        #                               nn.ReLU(True),
        #                               nn.Dropout(0.25),
        #                               nn.Linear(self.fcu,self.ndim),
        #                               nn.ReLU(True),
        #                               nn.Dropout(0.25),
        #                               )

    def forward(self, inputs):
        encoded = self.encoder(inputs)
        encoded = encoded.unsqueeze(-1).unsqueeze(-1)
        decoded = self.decoder(encoded)
        # print torch.max(decoded)
        # print torch.min(decoded)

        return decoded
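A shape sanity-check sketch for the dense encoder/decoder pair, using the values `get_model` passes for `'dnetccnl'` above; this is illustrative only and not part of the commit:

```python
import torch

net = dnetccnl(img_size=128, in_channels=3, out_channels=2, filters=32)
x = torch.rand(1, 3, 128, 128)   # two coord-conv channels are appended internally (3 + 2 = 5)
flow = net(x)
print(flow.shape)                # torch.Size([1, 2, 128, 128]) -- the (x, y) backward map
```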
@@ -1,89 +0,0 @@
import functools

import torch
import torch.nn as nn


# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
class UnetGenerator(nn.Module):
    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                 norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetGenerator, self).__init__()

        # construct unet structure
        unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer,
                                             innermost=True)
        for i in range(num_downs - 5):
            unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block,
                                                 norm_layer=norm_layer, use_dropout=use_dropout)
        unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block,
                                             norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
        unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True,
                                             norm_layer=norm_layer)

        self.model = unet_block

    def forward(self, input):
        return self.model(input)


# Defines the submodule with skip connection.
# X -------------------identity---------------------- X
#   |-- downsampling -- |submodule| -- upsampling --|
class UnetSkipConnectionBlock(nn.Module):
    def __init__(self, outer_nc, inner_nc, input_nc=None,
                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
        super(UnetSkipConnectionBlock, self).__init__()
        self.outermost = outermost
        if type(norm_layer) == functools.partial:
            use_bias = norm_layer.func == nn.InstanceNorm2d
        else:
            use_bias = norm_layer == nn.InstanceNorm2d
        if input_nc is None:
            input_nc = outer_nc
        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
                             stride=2, padding=1, bias=use_bias)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = norm_layer(inner_nc)
        uprelu = nn.ReLU(True)
        upnorm = norm_layer(outer_nc)

        if outermost:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
                                        kernel_size=4, stride=2,
                                        padding=1, bias=use_bias)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]

            if use_dropout:
                model = down + [submodule] + up + [nn.Dropout(0.5)]
            else:
                model = down + [submodule] + up

        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)
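A sketch matching the comment above: with `num_downs=7`, a 128x128 input is halved seven times down to 1x1 at the bottleneck and decoded back to the input resolution (illustrative only, not part of the commit):

```python
import torch

net = UnetGenerator(input_nc=3, output_nc=3, num_downs=7)
y = net(torch.rand(1, 3, 128, 128))
print(y.shape)   # torch.Size([1, 3, 128, 128]), Tanh-activated output
```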
@@ -1,17 +0,0 @@
'''
Misc Utility functions
'''
from collections import OrderedDict


def convert_state_dict(state_dict):
    """Converts a state dict saved from a dataParallel module to normal
    module state_dict inplace
    :param state_dict is the loaded DataParallel model_state

    """
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        name = k[7:]  # remove `module.`
        new_state_dict[name] = v
    return new_state_dict
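The conversion simply strips the `module.` prefix that `nn.DataParallel` adds to every parameter key. A tiny illustrative example:

```python
import torch

wrapped = {"module.conv1.weight": torch.zeros(3)}
print(convert_state_dict(wrapped))   # OrderedDict([('conv1.weight', tensor([0., 0., 0.]))])
```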
doc_dewarp/.gitignore (new file, 167 lines, vendored)
@@ -0,0 +1,167 @@
input/
output/

runs/
*_dump/
*_log/

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
doc_dewarp/.pre-commit-config.yaml (new file, 34 lines)
@@ -0,0 +1,34 @@
repos:
  # Common hooks
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.4.0
    hooks:
      - id: check-added-large-files
      - id: check-merge-conflict
      - id: check-symlinks
      - id: detect-private-key
      - id: end-of-file-fixer
      - id: trailing-whitespace
  - repo: https://github.com/Lucas-C/pre-commit-hooks.git
    rev: v1.5.1
    hooks:
      - id: remove-crlf
      - id: remove-tabs
        name: Tabs remover (Python)
        files: (.*\.(py|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
        args: [--whitespaces-count, '4']
  # For Python files
  - repo: https://github.com/psf/black.git
    rev: 23.3.0
    hooks:
      - id: black
        files: (.*\.(py|pyi|bzl)|BUILD|.*\.BUILD|WORKSPACE)$
  - repo: https://github.com/pycqa/isort
    rev: 5.11.5
    hooks:
      - id: isort
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.0.272
    hooks:
      - id: ruff
        args: [--fix, --exit-non-zero-on-fix, --no-cache]
doc_dewarp/GeoTr.py (new file, 398 lines)
@@ -0,0 +1,398 @@
import copy
from typing import Optional

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from .extractor import BasicEncoder
from .position_encoding import build_position_encoding
from .weight_init import weight_init_


class attnLayer(nn.Layer):
    def __init__(
        self,
        d_model,
        nhead=8,
        dim_feedforward=2048,
        dropout=0.1,
        activation="relu",
        normalize_before=False,
    ):
        super().__init__()

        self.self_attn = nn.MultiHeadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn_list = nn.LayerList(
            [
                copy.deepcopy(nn.MultiHeadAttention(d_model, nhead, dropout=dropout))
                for i in range(2)
            ]
        )

        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(p=dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2_list = nn.LayerList(
            [copy.deepcopy(nn.LayerNorm(d_model)) for i in range(2)]
        )

        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(p=dropout)
        self.dropout2_list = nn.LayerList(
            [copy.deepcopy(nn.Dropout(p=dropout)) for i in range(2)]
        )
        self.dropout3 = nn.Dropout(p=dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[paddle.Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(
        self,
        tgt,
        memory_list,
        tgt_mask=None,
        memory_mask=None,
        pos=None,
        memory_pos=None,
    ):
        q = k = self.with_pos_embed(tgt, pos)
        tgt2 = self.self_attn(
            q.transpose((1, 0, 2)),
            k.transpose((1, 0, 2)),
            value=tgt.transpose((1, 0, 2)),
            attn_mask=tgt_mask,
        )
        tgt2 = tgt2.transpose((1, 0, 2))
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

        for memory, multihead_attn, norm2, dropout2, m_pos in zip(
            memory_list,
            self.multihead_attn_list,
            self.norm2_list,
            self.dropout2_list,
            memory_pos,
        ):
            tgt2 = multihead_attn(
                query=self.with_pos_embed(tgt, pos).transpose((1, 0, 2)),
                key=self.with_pos_embed(memory, m_pos).transpose((1, 0, 2)),
                value=memory.transpose((1, 0, 2)),
                attn_mask=memory_mask,
            ).transpose((1, 0, 2))

            tgt = tgt + dropout2(tgt2)
            tgt = norm2(tgt)

        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)

        return tgt

    def forward_pre(
        self,
        tgt,
        memory,
        tgt_mask=None,
        memory_mask=None,
        pos=None,
        memory_pos=None,
    ):
        tgt2 = self.norm1(tgt)

        q = k = self.with_pos_embed(tgt2, pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask)
        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)

        tgt2 = self.multihead_attn(
            query=self.with_pos_embed(tgt2, pos),
            key=self.with_pos_embed(memory, memory_pos),
            value=memory,
            attn_mask=memory_mask,
        )
        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)

        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)

        return tgt

    def forward(
        self,
        tgt,
        memory_list,
        tgt_mask=None,
        memory_mask=None,
        pos=None,
        memory_pos=None,
    ):
        if self.normalize_before:
            return self.forward_pre(
                tgt,
                memory_list,
                tgt_mask,
                memory_mask,
                pos,
                memory_pos,
            )
        return self.forward_post(
            tgt,
            memory_list,
            tgt_mask,
            memory_mask,
            pos,
            memory_pos,
        )


def _get_clones(module, N):
    return nn.LayerList([copy.deepcopy(module) for i in range(N)])


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")


class TransDecoder(nn.Layer):
    def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
        super(TransDecoder, self).__init__()

        attn_layer = attnLayer(hidden_dim)
        self.layers = _get_clones(attn_layer, num_attn_layers)
        self.position_embedding = build_position_encoding(hidden_dim)

    def forward(self, image: paddle.Tensor, query_embed: paddle.Tensor):
        pos = self.position_embedding(
            paddle.ones([image.shape[0], image.shape[2], image.shape[3]], dtype="bool")
        )

        b, c, h, w = image.shape

        image = image.flatten(2).transpose(perm=[2, 0, 1])
        pos = pos.flatten(2).transpose(perm=[2, 0, 1])

        for layer in self.layers:
            query_embed = layer(query_embed, [image], pos=pos, memory_pos=[pos, pos])

        query_embed = query_embed.transpose(perm=[1, 2, 0]).reshape([b, c, h, w])

        return query_embed


class TransEncoder(nn.Layer):
    def __init__(self, num_attn_layers: int, hidden_dim: int = 128):
        super(TransEncoder, self).__init__()

        attn_layer = attnLayer(hidden_dim)
        self.layers = _get_clones(attn_layer, num_attn_layers)
        self.position_embedding = build_position_encoding(hidden_dim)

    def forward(self, image: paddle.Tensor):
        pos = self.position_embedding(
            paddle.ones([image.shape[0], image.shape[2], image.shape[3]], dtype="bool")
        )

        b, c, h, w = image.shape

        image = image.flatten(2).transpose(perm=[2, 0, 1])
        pos = pos.flatten(2).transpose(perm=[2, 0, 1])

        for layer in self.layers:
            image = layer(image, [image], pos=pos, memory_pos=[pos, pos])

        image = image.transpose(perm=[1, 2, 0]).reshape([b, c, h, w])

        return image


class FlowHead(nn.Layer):
    def __init__(self, input_dim=128, hidden_dim=256):
        super(FlowHead, self).__init__()

        self.conv1 = nn.Conv2D(input_dim, hidden_dim, 3, padding=1)
        self.conv2 = nn.Conv2D(hidden_dim, 2, 3, padding=1)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.conv2(self.relu(self.conv1(x)))


class UpdateBlock(nn.Layer):
    def __init__(self, hidden_dim: int = 128):
        super(UpdateBlock, self).__init__()

        self.flow_head = FlowHead(hidden_dim, hidden_dim=256)

        self.mask = nn.Sequential(
            nn.Conv2D(hidden_dim, 256, 3, padding=1),
            nn.ReLU(),
            nn.Conv2D(256, 64 * 9, 1, padding=0),
        )

    def forward(self, image, coords):
        mask = 0.25 * self.mask(image)
        dflow = self.flow_head(image)
        coords = coords + dflow
        return mask, coords


def coords_grid(batch, ht, wd):
    coords = paddle.meshgrid(paddle.arange(end=ht), paddle.arange(end=wd))
    coords = paddle.stack(coords[::-1], axis=0).astype(dtype="float32")
    return coords[None].tile([batch, 1, 1, 1])


def upflow8(flow, mode="bilinear"):
    new_size = 8 * flow.shape[2], 8 * flow.shape[3]
    return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)


class OverlapPatchEmbed(nn.Layer):
    """Image to Patch Embedding"""

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()

        img_size = img_size if isinstance(img_size, tuple) else (img_size, img_size)
        patch_size = (
            patch_size if isinstance(patch_size, tuple) else (patch_size, patch_size)
        )

        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W

        self.proj = nn.Conv2D(
            in_chans,
            embed_dim,
            patch_size,
            stride=stride,
            padding=(patch_size[0] // 2, patch_size[1] // 2),
        )
        self.norm = nn.LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            weight_init_(m, "trunc_normal_", std=0.02)
        elif isinstance(m, nn.LayerNorm):
            weight_init_(m, "Constant", value=1.0)
        elif isinstance(m, nn.Conv2D):
            weight_init_(
                m.weight, "kaiming_normal_", mode="fan_out", nonlinearity="relu"
            )

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2)

        perm = list(range(x.ndim))
        perm[1] = 2
        perm[2] = 1
        x = x.transpose(perm=perm)

        x = self.norm(x)

        return x, H, W


class GeoTr(nn.Layer):
    def __init__(self):
        super(GeoTr, self).__init__()

        self.hidden_dim = hdim = 256

        self.fnet = BasicEncoder(output_dim=hdim, norm_fn="instance")

        self.encoder_block = [("encoder_block" + str(i)) for i in range(3)]
        for i in self.encoder_block:
            self.__setattr__(i, TransEncoder(2, hidden_dim=hdim))

        self.down_layer = [("down_layer" + str(i)) for i in range(2)]
        for i in self.down_layer:
            self.__setattr__(i, nn.Conv2D(256, 256, 3, stride=2, padding=1))

        self.decoder_block = [("decoder_block" + str(i)) for i in range(3)]
        for i in self.decoder_block:
            self.__setattr__(i, TransDecoder(2, hidden_dim=hdim))

        self.up_layer = [("up_layer" + str(i)) for i in range(2)]
        for i in self.up_layer:
            self.__setattr__(
                i, nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            )

        self.query_embed = nn.Embedding(81, self.hidden_dim)

        self.update_block = UpdateBlock(self.hidden_dim)

    def initialize_flow(self, img):
        N, _, H, W = img.shape
        coodslar = coords_grid(N, H, W)
        coords0 = coords_grid(N, H // 8, W // 8)
        coords1 = coords_grid(N, H // 8, W // 8)
        return coodslar, coords0, coords1

    def upsample_flow(self, flow, mask):
        N, _, H, W = flow.shape

        mask = mask.reshape([N, 1, 9, 8, 8, H, W])
        mask = F.softmax(mask, axis=2)

        up_flow = F.unfold(8 * flow, [3, 3], paddings=1)
        up_flow = up_flow.reshape([N, 2, 9, 1, 1, H, W])

        up_flow = paddle.sum(mask * up_flow, axis=2)
        up_flow = up_flow.transpose(perm=[0, 1, 4, 2, 5, 3])

        return up_flow.reshape([N, 2, 8 * H, 8 * W])

    def forward(self, image):
        fmap = self.fnet(image)
        fmap = F.relu(fmap)

        fmap1 = self.__getattr__(self.encoder_block[0])(fmap)
        fmap1d = self.__getattr__(self.down_layer[0])(fmap1)
        fmap2 = self.__getattr__(self.encoder_block[1])(fmap1d)
        fmap2d = self.__getattr__(self.down_layer[1])(fmap2)
        fmap3 = self.__getattr__(self.encoder_block[2])(fmap2d)

        query_embed0 = self.query_embed.weight.unsqueeze(1).tile([1, fmap3.shape[0], 1])

        fmap3d_ = self.__getattr__(self.decoder_block[0])(fmap3, query_embed0)
        fmap3du_ = (
            self.__getattr__(self.up_layer[0])(fmap3d_)
            .flatten(2)
            .transpose(perm=[2, 0, 1])
        )
        fmap2d_ = self.__getattr__(self.decoder_block[1])(fmap2, fmap3du_)
        fmap2du_ = (
            self.__getattr__(self.up_layer[1])(fmap2d_)
            .flatten(2)
            .transpose(perm=[2, 0, 1])
        )
        fmap_out = self.__getattr__(self.decoder_block[2])(fmap1, fmap2du_)

        coodslar, coords0, coords1 = self.initialize_flow(image)
        coords1 = coords1.detach()
        mask, coords1 = self.update_block(fmap_out, coords1)
        flow_up = self.upsample_flow(coords1 - coords0, mask)
        bm_up = coodslar + flow_up

        return bm_up
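For orientation, a hedged shape sketch of the GeoTr forward pass. It assumes the companion modules referenced above (`extractor.py`, `position_encoding.py`, `weight_init.py`) are present in the same package, that `BasicEncoder` downsamples by 8 as the `initialize_flow`/`upsample_flow` logic implies, and a package import path that is not shown in this commit:

```python
import paddle

from doc_dewarp.GeoTr import GeoTr   # hypothetical import path

model = GeoTr()
model.eval()
x = paddle.rand([1, 3, 288, 288])    # the README trains at --img-size 288
with paddle.no_grad():
    bm = model(x)                    # dense backward map in absolute pixel coordinates
print(bm.shape)                      # expected: [1, 2, 288, 288]
```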
doc_dewarp/README.md (new file, 73 lines)
@@ -0,0 +1,73 @@
# DocTrPP: DocTr++ in PaddlePaddle

## Introduction

This is a PaddlePaddle implementation of DocTr++. The original paper is [DocTr++: Deep Unrestricted Document Image Rectification](https://arxiv.org/abs/2304.08796). The original code is [here](https://github.com/fh2019ustc/DocTr-Plus).



## Requirements

You need the latest version of PaddlePaddle; install it by following this [link](https://www.paddlepaddle.org.cn/).

## Training

1. Data Preparation

To prepare datasets, refer to [doc3D](https://github.com/cvlab-stonybrook/doc3D-dataset).

2. Training

```shell
sh train.sh
```

or

```shell
export OPENCV_IO_ENABLE_OPENEXR=1
export CUDA_VISIBLE_DEVICES=0

python train.py --img-size 288 \
    --name "DocTr++" \
    --batch-size 12 \
    --lr 2.5e-5 \
    --exist-ok \
    --use-vdl
```

3. Load Trained Model and Continue Training

```shell
export OPENCV_IO_ENABLE_OPENEXR=1
export CUDA_VISIBLE_DEVICES=0

python train.py --img-size 288 \
    --name "DocTr++" \
    --batch-size 12 \
    --lr 2.5e-5 \
    --resume "runs/train/DocTr++/weights/last.ckpt" \
    --exist-ok \
    --use-vdl
```

## Test and Inference

Test the dewarp result on a single image:

```shell
python predict.py -i "crop/12_2 copy.png" -m runs/train/DocTr++/weights/best.ckpt -o 12.2.png
```



## Export to ONNX

```shell
pip install paddle2onnx

python export.py -m ./best.ckpt --format onnx
```

## Model Download

The trained model can be downloaded from [here](https://github.com/GreatV/DocTrPP/releases/download/v0.0.2/best.ckpt).
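After the ONNX export step, the exported graph can be driven from Python roughly as below. This is a sketch only: the output filename, the input name, the 288x288 NCHW layout, and the pre/post-processing are assumptions not documented by the README:

```python
import numpy as np
import onnxruntime as ort

sess = ort.InferenceSession("best.onnx")                 # assumed export filename
inp_name = sess.get_inputs()[0].name
x = np.random.rand(1, 3, 288, 288).astype(np.float32)   # assumed float NCHW input
bm = sess.run(None, {inp_name: x})[0]                    # backward map used to remap the page
print(bm.shape)
```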
doc_dewarp/data_visualization.py (new file, 133 lines)
@@ -0,0 +1,133 @@
import argparse
import os
import random

import cv2
import hdf5storage as h5
import matplotlib.pyplot as plt
import numpy as np
import paddle
import paddle.nn.functional as F

os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"

parser = argparse.ArgumentParser()
parser.add_argument(
    "--data_root",
    nargs="?",
    type=str,
    default="~/datasets/doc3d/",
    help="Path to the downloaded dataset",
)
parser.add_argument(
    "--folder", nargs="?", type=int, default=1, help="Folder ID to read from"
)
parser.add_argument(
    "--output",
    nargs="?",
    type=str,
    default="output.png",
    help="Output filename for the image",
)

args = parser.parse_args()

root = os.path.expanduser(args.data_root)
folder = args.folder
dirname = os.path.join(root, "img", str(folder))

choices = [f for f in os.listdir(dirname) if "png" in f]
fname = random.choice(choices)

# Read Image
img_path = os.path.join(dirname, fname)
img = cv2.imread(img_path)

# Read 3D Coords
wc_path = os.path.join(root, "wc", str(folder), fname[:-3] + "exr")
wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)
# scale wc
# value obtained from the entire dataset
xmx, xmn, ymx, ymn, zmx, zmn = (
    1.2539363,
    -1.2442188,
    1.2396319,
    -1.2289206,
    0.6436657,
    -0.67492497,
)
wc[:, :, 0] = (wc[:, :, 0] - zmn) / (zmx - zmn)
wc[:, :, 1] = (wc[:, :, 1] - ymn) / (ymx - ymn)
wc[:, :, 2] = (wc[:, :, 2] - xmn) / (xmx - xmn)

# Read Backward Map
bm_path = os.path.join(root, "bm", str(folder), fname[:-3] + "mat")
bm = h5.loadmat(bm_path)["bm"]

# Read UV Map
uv_path = os.path.join(root, "uv", str(folder), fname[:-3] + "exr")
uv = cv2.imread(uv_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)

# Read Depth Map
dmap_path = os.path.join(root, "dmap", str(folder), fname[:-3] + "exr")
dmap = cv2.imread(dmap_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)[:, :, 0]
# do some clipping and scaling to display it
dmap[dmap > 30.0] = 30
dmap = 1 - ((dmap - np.min(dmap)) / (np.max(dmap) - np.min(dmap)))

# Read Normal Map
norm_path = os.path.join(root, "norm", str(folder), fname[:-3] + "exr")
norm = cv2.imread(norm_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)

# Read Albedo
alb_path = os.path.join(root, "alb", str(folder), fname[:-3] + "png")
alb = cv2.imread(alb_path)

# Read Checkerboard Image
recon_path = os.path.join(root, "recon", str(folder), fname[:-8] + "chess480001.png")
recon = cv2.imread(recon_path)

# Display image and GTs

# use the backward mapping to dewarp the image
# scale bm to -1.0 to 1.0
bm_ = bm / np.array([448, 448])
bm_ = (bm_ - 0.5) * 2
bm_ = np.reshape(bm_, (1, 448, 448, 2))
bm_ = paddle.to_tensor(bm_, dtype="float32")
img_ = alb.transpose((2, 0, 1)).astype(np.float32) / 255.0
img_ = np.expand_dims(img_, 0)
img_ = paddle.to_tensor(img_, dtype="float32")
uw = F.grid_sample(img_, bm_)
uw = uw[0].numpy().transpose((1, 2, 0))

f, axrr = plt.subplots(2, 5)
for ax in axrr:
    for a in ax:
        a.set_xticks([])
        a.set_yticks([])

axrr[0][0].imshow(img)
axrr[0][0].title.set_text("image")
axrr[0][1].imshow(wc)
axrr[0][1].title.set_text("3D coords")
axrr[0][2].imshow(bm[:, :, 0])
axrr[0][2].title.set_text("bm 0")
axrr[0][3].imshow(bm[:, :, 1])
axrr[0][3].title.set_text("bm 1")
if uv is None:
    uv = np.zeros_like(img)
axrr[0][4].imshow(uv)
axrr[0][4].title.set_text("uv map")
axrr[1][0].imshow(dmap)
axrr[1][0].title.set_text("depth map")
axrr[1][1].imshow(norm)
axrr[1][1].title.set_text("normal map")
axrr[1][2].imshow(alb)
axrr[1][2].title.set_text("albedo")
axrr[1][3].imshow(recon)
axrr[1][3].title.set_text("checkerboard")
axrr[1][4].imshow(uw)
axrr[1][4].title.set_text("gt unwarped")
plt.tight_layout()
plt.savefig(args.output)
doc_dewarp/dewarp.py (new file, 26 lines)
@@ -0,0 +1,26 @@
import cv2
import paddle

from .GeoTr import GeoTr
from .utils import to_tensor, to_image


def dewarp_image(image):
    model_path = "model/dewarp_model/best.ckpt"

    checkpoint = paddle.load(model_path)
    state_dict = checkpoint["model"]
    model = GeoTr()
    model.set_state_dict(state_dict)
    model.eval()

    img = cv2.resize(image, (288, 288))
    x = to_tensor(img)
    y = to_tensor(image)
    bm = model(x)
    bm = paddle.nn.functional.interpolate(
        bm, y.shape[2:], mode="bilinear", align_corners=False
    )
    bm_nhwc = bm.transpose([0, 2, 3, 1])
    out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
    return to_image(out)
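Usage mirrors the removed PyTorch helper: pass an image loaded with OpenCV and get back the rectified image. The checkpoint location `model/dewarp_model/best.ckpt` comes from the code above, but the package import path is an assumption:

```python
import cv2

from doc_dewarp.dewarp import dewarp_image   # hypothetical import path

bgr = cv2.imread("input/page.jpg")
flat = dewarp_image(bgr)                     # GeoTr backward map resized to the input, then grid_sample
cv2.imwrite("output/page_flat.jpg", flat)
```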
doc_dewarp/doc/download_dataset.sh (new file, 161 lines)
@@ -0,0 +1,161 @@
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_1.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_2.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_3.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_4.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_5.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_6.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_7.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_8.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_9.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_10.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_11.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_12.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_13.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_14.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_15.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_16.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_17.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_18.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_19.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_20.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/img_21.zip"

wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_1.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_2.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_3.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_4.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_5.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_6.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_7.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_8.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_9.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_10.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_11.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_12.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_13.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_14.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_15.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_16.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_17.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_18.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_19.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_20.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/wc_21.zip"

wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_1.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_2.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_3.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_4.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_5.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_6.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_7.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_8.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_9.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_10.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_11.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_12.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_13.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_14.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_15.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_16.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_17.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_18.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_19.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_20.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/bm_21.zip"

wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_1.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_2.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_3.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_4.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_5.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_6.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_7.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_8.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_10.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_11.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_12.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_13.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_14.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_15.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_16.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_17.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_18.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_19.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_20.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/uv_21.zip"

wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_1.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_2.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_3.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_4.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_5.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_6.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_7.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/alb_8.zip"

wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_1.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_2.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_3.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_4.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_5.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_6.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_7.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_8.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_9.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_10.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_11.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_12.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_13.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_14.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_15.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_16.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_17.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_18.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_19.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_20.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/recon_21.zip"

wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_1.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_2.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_3.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_4.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_5.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_6.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_7.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_8.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_9.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_10.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_11.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_12.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_13.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_14.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_15.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_16.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_17.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_18.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_19.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_20.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/norm_21.zip"

wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_1.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_2.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_3.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_4.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_5.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_6.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_7.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_8.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_9.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_10.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_11.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_12.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_13.zip"
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_14.zip"
|
||||||
|
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_15.zip"
|
||||||
|
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_16.zip"
|
||||||
|
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_17.zip"
|
||||||
|
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_18.zip"
|
||||||
|
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_19.zip"
|
||||||
|
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_20.zip"
|
||||||
|
wget "https://bj.bcebos.com/paddleseg/paddleseg/datasets/doc3d/dmap_21.zip"
|
||||||
BIN doc_dewarp/doc/imgs/document_image_rectification.jpg Normal file
Binary file not shown. (new image, 76 KiB)
129 doc_dewarp/doc3d_dataset.py Normal file
@@ -0,0 +1,129 @@
import collections
import os
import random

import albumentations as A
import cv2
import hdf5storage as h5
import numpy as np
import paddle
from paddle import io

# Set random seed
random.seed(12345678)


class Doc3dDataset(io.Dataset):
    def __init__(self, root, split="train", is_augment=False, image_size=512):
        self.root = os.path.expanduser(root)

        self.split = split
        self.is_augment = is_augment

        self.files = collections.defaultdict(list)

        self.image_size = (
            image_size if isinstance(image_size, tuple) else (image_size, image_size)
        )

        # Augmentation
        self.augmentation = A.Compose(
            [
                A.ColorJitter(),
            ]
        )

        for split in ["train", "val"]:
            path = os.path.join(self.root, split + ".txt")
            file_list = []
            with open(path, "r") as file:
                file_list = [file_id.rstrip() for file_id in file.readlines()]
            self.files[split] = file_list

    def __len__(self):
        return len(self.files[self.split])

    def __getitem__(self, index):
        image_name = self.files[self.split][index]

        # Read image
        image_path = os.path.join(self.root, "img", image_name + ".png")
        image = cv2.imread(image_path)

        # Read 3D Coordinates
        wc_path = os.path.join(self.root, "wc", image_name + ".exr")
        wc = cv2.imread(wc_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH)

        # Read backward map
        bm_path = os.path.join(self.root, "bm", image_name + ".mat")
        bm = h5.loadmat(bm_path)["bm"]

        image, bm = self.transform(wc, bm, image)

        return image, bm

    def tight_crop(self, wc: np.ndarray):
        mask = ((wc[:, :, 0] != 0) & (wc[:, :, 1] != 0) & (wc[:, :, 2] != 0)).astype(
            np.uint8
        )
        mask_size = mask.shape
        [y, x] = mask.nonzero()
        min_x = min(x)
        max_x = max(x)
        min_y = min(y)
        max_y = max(y)

        wc = wc[min_y : max_y + 1, min_x : max_x + 1, :]
        s = 10
        wc = np.pad(wc, ((s, s), (s, s), (0, 0)), "constant")

        cx1 = random.randint(0, 2 * s)
        cx2 = random.randint(0, 2 * s) + 1
        cy1 = random.randint(0, 2 * s)
        cy2 = random.randint(0, 2 * s) + 1

        wc = wc[cy1:-cy2, cx1:-cx2, :]

        top: int = min_y - s + cy1
        bottom: int = mask_size[0] - max_y - s + cy2
        left: int = min_x - s + cx1
        right: int = mask_size[1] - max_x - s + cx2

        top = np.clip(top, 0, mask_size[0])
        bottom = np.clip(bottom, 1, mask_size[0] - 1)
        left = np.clip(left, 0, mask_size[1])
        right = np.clip(right, 1, mask_size[1] - 1)

        return wc, top, bottom, left, right

    def transform(self, wc, bm, img):
        wc, top, bottom, left, right = self.tight_crop(wc)

        img = img[top:-bottom, left:-right, :]

        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        if self.is_augment:
            img = self.augmentation(image=img)["image"]

        # resize image
        img = cv2.resize(img, self.image_size)
        img = img.astype(np.float32) / 255.0
        img = img.transpose(2, 0, 1)

        # resize bm
        bm = bm.astype(np.float32)
        bm[:, :, 1] = bm[:, :, 1] - top
        bm[:, :, 0] = bm[:, :, 0] - left
        bm = bm / np.array([448.0 - left - right, 448.0 - top - bottom])
        bm0 = cv2.resize(bm[:, :, 0], (self.image_size[0], self.image_size[1]))
        bm1 = cv2.resize(bm[:, :, 1], (self.image_size[0], self.image_size[1]))
        bm0 = bm0 * self.image_size[0]
        bm1 = bm1 * self.image_size[1]

        bm = np.stack([bm0, bm1], axis=-1)

        img = paddle.to_tensor(img).astype(dtype="float32")
        bm = paddle.to_tensor(bm).astype(dtype="float32")

        return img, bm
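A minimal usage sketch for the dataset above (illustrative only, not part of the committed files). It assumes the doc3d archives have been unpacked under ~/datasets/doc3d, the train/val lists have been produced by split_dataset.py, the script is run from the doc_dewarp directory, and OPENCV_IO_ENABLE_OPENEXR=1 is exported (train.sh sets it) so cv2 can read the .exr files; batch size and worker count are illustrative:

# Hedged sketch: feeding Doc3dDataset into a paddle DataLoader.
from paddle.io import DataLoader
from doc3d_dataset import Doc3dDataset

dataset = Doc3dDataset("~/datasets/doc3d", split="train", is_augment=True, image_size=288)
loader = DataLoader(dataset, batch_size=4, shuffle=True, num_workers=2)

for img, bm in loader:
    # img: [N, 3, 288, 288] float32 in [0, 1]; bm: [N, 288, 288, 2] backward map scaled to image_size
    print(img.shape, bm.shape)
    break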
66 doc_dewarp/export.py Normal file
@@ -0,0 +1,66 @@
import argparse
import os

import paddle

from GeoTr import GeoTr


def export(args):
    model_path = args.model
    imgsz = args.imgsz
    format = args.format

    model = GeoTr()
    checkpoint = paddle.load(model_path)
    model.set_state_dict(checkpoint["model"])
    model.eval()

    dirname = os.path.dirname(model_path)
    if format == "static" or format == "onnx":
        model = paddle.jit.to_static(
            model,
            input_spec=[
                paddle.static.InputSpec(shape=[1, 3, imgsz, imgsz], dtype="float32")
            ],
            full_graph=True,
        )
        paddle.jit.save(model, os.path.join(dirname, "model"))

    if format == "onnx":
        onnx_path = os.path.join(dirname, "model.onnx")
        os.system(
            f"paddle2onnx --model_dir {dirname}"
            " --model_filename model.pdmodel"
            " --params_filename model.pdiparams"
            f" --save_file {onnx_path}"
            " --opset_version 11"
        )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="export model")

    parser.add_argument(
        "--model",
        "-m",
        nargs="?",
        type=str,
        default="",
        help="The path of model",
    )

    parser.add_argument(
        "--imgsz", type=int, default=288, help="The size of input image"
    )

    parser.add_argument(
        "--format",
        type=str,
        default="static",
        help="The format of exported model, which can be static or onnx",
    )

    args = parser.parse_args()

    export(args)
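A hedged sketch of consuming the exported static graph (illustrative only, not part of the committed files). It assumes export.py was run with --format static, so model.pdmodel/model.pdiparams were written next to the checkpoint under the "model" prefix passed to paddle.jit.save; the directory path is hypothetical:

# Hedged sketch: reload the static-graph model saved by export.py above.
import paddle

loaded = paddle.jit.load("path/to/checkpoint_dir/model")  # hypothetical path, matches the "model" prefix
loaded.eval()

x = paddle.randn([1, 3, 288, 288], dtype="float32")  # matches the exported InputSpec
bm = loaded(x)
print(bm.shape)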
110 doc_dewarp/extractor.py Normal file
@@ -0,0 +1,110 @@
import paddle.nn as nn

from .weight_init import kaiming_normal_, weight_init_


class ResidualBlock(nn.Layer):
    """Residual Block with custom normalization."""

    def __init__(self, in_planes, planes, norm_fn="group", stride=1):
        super(ResidualBlock, self).__init__()

        self.conv1 = nn.Conv2D(in_planes, planes, 3, padding=1, stride=stride)
        self.conv2 = nn.Conv2D(planes, planes, 3, padding=1)
        self.relu = nn.ReLU()

        if norm_fn == "group":
            num_groups = planes // 8
            self.norm1 = nn.GroupNorm(num_groups, planes)
            self.norm2 = nn.GroupNorm(num_groups, planes)
            if not stride == 1:
                self.norm3 = nn.GroupNorm(num_groups, planes)
        elif norm_fn == "batch":
            self.norm1 = nn.BatchNorm2D(planes)
            self.norm2 = nn.BatchNorm2D(planes)
            if not stride == 1:
                self.norm3 = nn.BatchNorm2D(planes)
        elif norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2D(planes)
            self.norm2 = nn.InstanceNorm2D(planes)
            if not stride == 1:
                self.norm3 = nn.InstanceNorm2D(planes)
        elif norm_fn == "none":
            self.norm1 = nn.Sequential()
            self.norm2 = nn.Sequential()
            if not stride == 1:
                self.norm3 = nn.Sequential()

        if stride == 1:
            self.downsample = None
        else:
            self.downsample = nn.Sequential(
                nn.Conv2D(in_planes, planes, 1, stride=stride), self.norm3
            )

    def forward(self, x):
        y = x
        y = self.relu(self.norm1(self.conv1(y)))
        y = self.relu(self.norm2(self.conv2(y)))

        if self.downsample is not None:
            x = self.downsample(x)

        return self.relu(x + y)


class BasicEncoder(nn.Layer):
    """Basic Encoder with custom normalization."""

    def __init__(self, output_dim=128, norm_fn="batch"):
        super(BasicEncoder, self).__init__()

        self.norm_fn = norm_fn

        if self.norm_fn == "group":
            self.norm1 = nn.GroupNorm(8, 64)
        elif self.norm_fn == "batch":
            self.norm1 = nn.BatchNorm2D(64)
        elif self.norm_fn == "instance":
            self.norm1 = nn.InstanceNorm2D(64)
        elif self.norm_fn == "none":
            self.norm1 = nn.Sequential()

        self.conv1 = nn.Conv2D(3, 64, 7, stride=2, padding=3)
        self.relu1 = nn.ReLU()

        self.in_planes = 64
        self.layer1 = self._make_layer(64, stride=1)
        self.layer2 = self._make_layer(128, stride=2)
        self.layer3 = self._make_layer(192, stride=2)

        self.conv2 = nn.Conv2D(192, output_dim, 1)

        for m in self.sublayers():
            if isinstance(m, nn.Conv2D):
                # Kaiming (fan-out) init for conv kernels via kaiming_normal_ from weight_init.py
                kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
            elif isinstance(m, (nn.BatchNorm2D, nn.InstanceNorm2D, nn.GroupNorm)):
                weight_init_(m, "Constant", value=1, bias_value=0.0)

    def _make_layer(self, dim, stride=1):
        layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
        layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
        layers = layer1, layer2

        self.in_planes = dim
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu1(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)

        x = self.conv2(x)

        return x
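A quick shape check for the encoder (illustrative only, not part of the committed files). It assumes doc_dewarp is importable as a package, which the relative import in extractor.py requires, and uses the 288-pixel input size from the training defaults:

# Hedged sketch: conv1 plus the two stride-2 stages reduce a 288x288 input to
# roughly 1/8 resolution, so the feature map should come out as [1, 128, 36, 36].
import paddle
from doc_dewarp.extractor import BasicEncoder  # assumes package-style import

encoder = BasicEncoder(output_dim=128)
feats = encoder(paddle.randn([1, 3, 288, 288]))
print(feats.shape)  # expected: [1, 128, 36, 36]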
48 doc_dewarp/plots.py Normal file
@@ -0,0 +1,48 @@
import copy
import os

import matplotlib.pyplot as plt
import paddle.optimizer as optim

from GeoTr import GeoTr


def plot_lr_scheduler(optimizer, scheduler, epochs=65, save_dir=""):
    """
    Plot the learning rate scheduler
    """

    optimizer = copy.copy(optimizer)
    scheduler = copy.copy(scheduler)

    lr = []
    for _ in range(epochs):
        for _ in range(30):
            lr.append(scheduler.get_lr())
            optimizer.step()
            scheduler.step()

    epoch = [float(i) / 30.0 for i in range(len(lr))]

    plt.figure()
    plt.plot(epoch, lr, ".-", label="Learning Rate")
    plt.xlabel("epoch")
    plt.ylabel("Learning Rate")
    plt.title("Learning Rate Scheduler")
    plt.savefig(os.path.join(save_dir, "lr_scheduler.png"), dpi=300)
    plt.close()


if __name__ == "__main__":
    model = GeoTr()

    scheduler = optim.lr.OneCycleLR(
        max_learning_rate=1e-4,
        total_steps=1950,
        phase_pct=0.1,
        end_learning_rate=1e-4 / 2.5e5,
    )
    optimizer = optim.AdamW(learning_rate=scheduler, parameters=model.parameters())
    plot_lr_scheduler(
        scheduler=scheduler, optimizer=optimizer, epochs=65, save_dir="./"
    )
124 doc_dewarp/position_encoding.py Normal file
@@ -0,0 +1,124 @@
import math
from typing import Optional

import paddle
import paddle.nn as nn
import paddle.nn.initializer as init


class NestedTensor(object):
    def __init__(self, tensors, mask: Optional[paddle.Tensor]):
        self.tensors = tensors
        self.mask = mask

    def decompose(self):
        return self.tensors, self.mask

    def __repr__(self):
        return str(self.tensors)


class PositionEmbeddingSine(nn.Layer):
    """
    This is a more standard version of the position embedding, very similar to the one
    used by the Attention is all you need paper, generalized to work on images.
    """

    def __init__(
        self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
    ):
        super().__init__()
        self.num_pos_feats = num_pos_feats
        self.temperature = temperature
        self.normalize = normalize

        if scale is not None and normalize is False:
            raise ValueError("normalize should be True if scale is passed")

        if scale is None:
            scale = 2 * math.pi

        self.scale = scale

    def forward(self, mask):
        assert mask is not None

        y_embed = mask.cumsum(axis=1, dtype="float32")
        x_embed = mask.cumsum(axis=2, dtype="float32")

        if self.normalize:
            eps = 1e-06
            y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
            x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale

        dim_t = paddle.arange(end=self.num_pos_feats, dtype="float32")
        dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)

        pos_x = x_embed[:, :, :, None] / dim_t
        pos_y = y_embed[:, :, :, None] / dim_t

        pos_x = paddle.stack(
            (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), axis=4
        ).flatten(3)
        pos_y = paddle.stack(
            (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), axis=4
        ).flatten(3)

        pos = paddle.concat((pos_y, pos_x), axis=3).transpose(perm=[0, 3, 1, 2])

        return pos


class PositionEmbeddingLearned(nn.Layer):
    """
    Absolute pos embedding, learned.
    """

    def __init__(self, num_pos_feats=256):
        super().__init__()
        self.row_embed = nn.Embedding(50, num_pos_feats)
        self.col_embed = nn.Embedding(50, num_pos_feats)
        self.reset_parameters()

    def reset_parameters(self):
        init_Constant = init.Uniform()
        init_Constant(self.row_embed.weight)
        init_Constant(self.col_embed.weight)

    def forward(self, tensor_list: NestedTensor):
        x = tensor_list.tensors

        h, w = x.shape[-2:]

        i = paddle.arange(end=w)
        j = paddle.arange(end=h)

        x_emb = self.col_embed(i)
        y_emb = self.row_embed(j)

        pos = (
            paddle.concat(
                [
                    x_emb.unsqueeze(0).tile([h, 1, 1]),
                    y_emb.unsqueeze(1).tile([1, w, 1]),
                ],
                axis=-1,
            )
            .transpose([2, 0, 1])
            .unsqueeze(0)
            .tile([x.shape[0], 1, 1, 1])
        )

        return pos


def build_position_encoding(hidden_dim=512, position_embedding="sine"):
    N_steps = hidden_dim // 2

    if position_embedding in ("v2", "sine"):
        position_embedding = PositionEmbeddingSine(N_steps, normalize=True)
    elif position_embedding in ("v3", "learned"):
        position_embedding = PositionEmbeddingLearned(N_steps)
    else:
        raise ValueError(f"not supported {position_embedding}")
    return position_embedding
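A hedged sanity check for the sine embedding (illustrative only, not part of the committed files). The mask shape and hidden_dim below are assumptions; it is run as a plain module from the doc_dewarp directory:

# Hedged sketch: PositionEmbeddingSine emits num_pos_feats channels each for x and y,
# so hidden_dim=128 -> N_steps=64 -> a [B, 128, H, W] embedding.
import paddle
from position_encoding import build_position_encoding

pos_enc = build_position_encoding(hidden_dim=128, position_embedding="sine")
mask = paddle.ones([2, 36, 36])  # all-valid mask for a 36x36 feature map
pos = pos_enc(mask)
print(pos.shape)  # expected: [2, 128, 36, 36]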
69 doc_dewarp/predict.py Normal file
@@ -0,0 +1,69 @@
import argparse

import cv2
import paddle

from GeoTr import GeoTr
from utils import to_image, to_tensor


def run(args):
    image_path = args.image
    model_path = args.model
    output_path = args.output

    checkpoint = paddle.load(model_path)
    state_dict = checkpoint["model"]
    model = GeoTr()
    model.set_state_dict(state_dict)
    model.eval()

    img_org = cv2.imread(image_path)
    img = cv2.resize(img_org, (288, 288))
    x = to_tensor(img)
    y = to_tensor(img_org)
    bm = model(x)
    bm = paddle.nn.functional.interpolate(
        bm, y.shape[2:], mode="bilinear", align_corners=False
    )
    bm_nhwc = bm.transpose([0, 2, 3, 1])
    out = paddle.nn.functional.grid_sample(y, (bm_nhwc / 288 - 0.5) * 2)
    out_image = to_image(out)
    cv2.imwrite(output_path, out_image)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="predict")

    parser.add_argument(
        "--image",
        "-i",
        nargs="?",
        type=str,
        default="",
        help="The path of image",
    )

    parser.add_argument(
        "--model",
        "-m",
        nargs="?",
        type=str,
        default="",
        help="The path of model",
    )

    parser.add_argument(
        "--output",
        "-o",
        nargs="?",
        type=str,
        default="",
        help="The path of output",
    )

    args = parser.parse_args()

    print(args)

    run(args)
7 doc_dewarp/requirements.txt Normal file
@@ -0,0 +1,7 @@
hdf5storage
loguru
numpy
scipy
opencv-python
matplotlib
albumentations
57 doc_dewarp/split_dataset.py Normal file
@@ -0,0 +1,57 @@
import argparse
import glob
import os
import random
from pathlib import Path

from loguru import logger

random.seed(1234567)


def run(args):
    data_root = os.path.expanduser(args.data_root)
    ratio = args.train_ratio

    data_path = os.path.join(data_root, "img", "*", "*.png")
    img_list = glob.glob(data_path, recursive=True)
    img_list.sort()  # sort in place so the shuffle below is reproducible
    random.shuffle(img_list)

    train_size = int(len(img_list) * ratio)

    train_text_path = os.path.join(data_root, "train.txt")
    with open(train_text_path, "w") as file:
        for item in img_list[:train_size]:
            parts = Path(item).parts
            item = os.path.join(parts[-2], parts[-1])
            file.write("%s\n" % item.split(".png")[0])

    val_text_path = os.path.join(data_root, "val.txt")
    with open(val_text_path, "w") as file:
        for item in img_list[train_size:]:
            parts = Path(item).parts
            item = os.path.join(parts[-2], parts[-1])
            file.write("%s\n" % item.split(".png")[0])

    logger.info(f"TRAIN LABEL: {train_text_path}")
    logger.info(f"VAL LABEL: {val_text_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--data_root",
        type=str,
        default="~/datasets/doc3d",
        help="Data path to load data",
    )
    parser.add_argument(
        "--train_ratio", type=float, default=0.8, help="Ratio of training data"
    )

    args = parser.parse_args()

    logger.info(args)

    run(args)
381 doc_dewarp/train.py Normal file
@@ -0,0 +1,381 @@
import argparse
import inspect
import os
import random
from pathlib import Path
from typing import Optional

import numpy as np
import paddle
import paddle.nn as nn
import paddle.optimizer as optim
from loguru import logger
from paddle.io import DataLoader
from paddle.nn import functional as F
from paddle_msssim import ms_ssim, ssim

from doc3d_dataset import Doc3dDataset
from GeoTr import GeoTr
from utils import to_image

FILE = Path(__file__).resolve()
ROOT = FILE.parents[0]
RANK = int(os.getenv("RANK", -1))


def init_seeds(seed=0, deterministic=False):
    random.seed(seed)
    np.random.seed(seed)
    paddle.seed(seed)
    if deterministic:
        os.environ["FLAGS_cudnn_deterministic"] = "1"
        os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
        os.environ["PYTHONHASHSEED"] = str(seed)


def colorstr(*input):
    # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code,
    # i.e. colorstr('blue', 'hello world')

    *args, string = (
        input if len(input) > 1 else ("blue", "bold", input[0])
    )  # color arguments, string

    colors = {
        "black": "\033[30m",  # basic colors
        "red": "\033[31m",
        "green": "\033[32m",
        "yellow": "\033[33m",
        "blue": "\033[34m",
        "magenta": "\033[35m",
        "cyan": "\033[36m",
        "white": "\033[37m",
        "bright_black": "\033[90m",  # bright colors
        "bright_red": "\033[91m",
        "bright_green": "\033[92m",
        "bright_yellow": "\033[93m",
        "bright_blue": "\033[94m",
        "bright_magenta": "\033[95m",
        "bright_cyan": "\033[96m",
        "bright_white": "\033[97m",
        "end": "\033[0m",  # misc
        "bold": "\033[1m",
        "underline": "\033[4m",
    }

    return "".join(colors[x] for x in args) + f"{string}" + colors["end"]


def print_args(args: Optional[dict] = None, show_file=True, show_func=False):
    # Print function arguments (optional args dict)
    x = inspect.currentframe().f_back  # previous frame
    file, _, func, _, _ = inspect.getframeinfo(x)
    if args is None:  # get args automatically
        args, _, _, frm = inspect.getargvalues(x)
        args = {k: v for k, v in frm.items() if k in args}
    try:
        file = Path(file).resolve().relative_to(ROOT).with_suffix("")
    except ValueError:
        file = Path(file).stem
    s = (f"{file}: " if show_file else "") + (f"{func}: " if show_func else "")
    logger.info(colorstr(s) + ", ".join(f"{k}={v}" for k, v in args.items()))


def increment_path(path, exist_ok=False, sep="", mkdir=False):
    # Increment file or directory path,
    # i.e. runs/exp --> runs/exp{sep}2, runs/exp{sep}3, ... etc.
    path = Path(path)  # os-agnostic
    if path.exists() and not exist_ok:
        path, suffix = (
            (path.with_suffix(""), path.suffix) if path.is_file() else (path, "")
        )

        for n in range(2, 9999):
            p = f"{path}{sep}{n}{suffix}"  # increment path
            if not os.path.exists(p):
                break
        path = Path(p)

    if mkdir:
        path.mkdir(parents=True, exist_ok=True)  # make directory

    return path


def train(args):
    save_dir = Path(args.save_dir)

    use_vdl = args.use_vdl
    if use_vdl:
        from visualdl import LogWriter

        log_dir = save_dir / "vdl"
        vdl_writer = LogWriter(str(log_dir))

    # Directories
    weights_dir = save_dir / "weights"
    weights_dir.mkdir(parents=True, exist_ok=True)  # create save_dir/weights for checkpoints

    last = weights_dir / "last.ckpt"
    best = weights_dir / "best.ckpt"

    # Hyperparameters

    # Config
    init_seeds(args.seed)

    # Train loader
    train_dataset = Doc3dDataset(
        args.data_root,
        split="train",
        is_augment=True,
        image_size=args.img_size,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.workers,
    )

    # Validation loader
    val_dataset = Doc3dDataset(
        args.data_root,
        split="val",
        is_augment=False,
        image_size=args.img_size,
    )

    val_loader = DataLoader(
        val_dataset, batch_size=args.batch_size, num_workers=args.workers
    )

    # Model
    model = GeoTr()

    if use_vdl:
        vdl_writer.add_graph(
            model,
            input_spec=[
                paddle.static.InputSpec([1, 3, args.img_size, args.img_size], "float32")
            ],
        )

    # Data Parallel Mode
    if RANK == -1 and paddle.device.cuda.device_count() > 1:
        model = paddle.DataParallel(model)

    # Scheduler
    scheduler = optim.lr.OneCycleLR(
        max_learning_rate=args.lr,
        total_steps=args.epochs * len(train_loader),
        phase_pct=0.1,
        end_learning_rate=args.lr / 2.5e5,
    )

    # Optimizer
    optimizer = optim.AdamW(
        learning_rate=scheduler,
        parameters=model.parameters(),
    )

    # loss function
    l1_loss_fn = nn.L1Loss()
    mse_loss_fn = nn.MSELoss()

    # Resume
    best_fitness, start_epoch = 0.0, 0
    if args.resume:
        ckpt = paddle.load(args.resume)
        model.set_state_dict(ckpt["model"])
        optimizer.set_state_dict(ckpt["optimizer"])
        scheduler.set_state_dict(ckpt["scheduler"])
        best_fitness = ckpt["best_fitness"]
        start_epoch = ckpt["epoch"] + 1

    # Train
    for epoch in range(start_epoch, args.epochs):
        model.train()

        for i, (img, target) in enumerate(train_loader):
            img = paddle.to_tensor(img)  # NCHW
            target = paddle.to_tensor(target)  # NHWC

            pred = model(img)  # NCHW
            pred_nhwc = pred.transpose([0, 2, 3, 1])

            loss = l1_loss_fn(pred_nhwc, target)
            mse_loss = mse_loss_fn(pred_nhwc, target)

            if use_vdl:
                vdl_writer.add_scalar(
                    "Train/L1 Loss", float(loss), epoch * len(train_loader) + i
                )
                vdl_writer.add_scalar(
                    "Train/MSE Loss", float(mse_loss), epoch * len(train_loader) + i
                )
                vdl_writer.add_scalar(
                    "Train/Learning Rate",
                    float(scheduler.get_lr()),
                    epoch * len(train_loader) + i,
                )

            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.clear_grad()

            if i % 10 == 0:
                logger.info(
                    f"[TRAIN MODE] Epoch: {epoch}, Iter: {i}, L1 Loss: {float(loss)}, "
                    f"MSE Loss: {float(mse_loss)}, LR: {float(scheduler.get_lr())}"
                )

        # Validation
        model.eval()

        with paddle.no_grad():
            avg_ssim = paddle.zeros([])
            avg_ms_ssim = paddle.zeros([])
            avg_l1_loss = paddle.zeros([])
            avg_mse_loss = paddle.zeros([])

            for i, (img, target) in enumerate(val_loader):
                img = paddle.to_tensor(img)
                target = paddle.to_tensor(target)

                pred = model(img)
                pred_nhwc = pred.transpose([0, 2, 3, 1])

                # predict image
                out = F.grid_sample(img, (pred_nhwc / args.img_size - 0.5) * 2)
                out_gt = F.grid_sample(img, (target / args.img_size - 0.5) * 2)

                # calculate ssim
                ssim_val = ssim(out, out_gt, data_range=1.0)
                ms_ssim_val = ms_ssim(out, out_gt, data_range=1.0)

                loss = l1_loss_fn(pred_nhwc, target)
                mse_loss = mse_loss_fn(pred_nhwc, target)

                # calculate fitness
                avg_ssim += ssim_val
                avg_ms_ssim += ms_ssim_val
                avg_l1_loss += loss
                avg_mse_loss += mse_loss

                if i % 10 == 0:
                    logger.info(
                        f"[VAL MODE] Epoch: {epoch}, VAL Iter: {i}, "
                        f"L1 Loss: {float(loss)} MSE Loss: {float(mse_loss)}, "
                        f"MS-SSIM: {float(ms_ssim_val)}, SSIM: {float(ssim_val)}"
                    )

                if use_vdl and i == 0:
                    img_0 = to_image(out[0])
                    img_gt_0 = to_image(out_gt[0])
                    vdl_writer.add_image("Val/Predicted Image No.0", img_0, epoch)
                    vdl_writer.add_image("Val/Target Image No.0", img_gt_0, epoch)

                    img_1 = to_image(out[1])
                    img_gt_1 = to_image(out_gt[1])
                    img_gt_1 = img_gt_1.astype("uint8")
                    vdl_writer.add_image("Val/Predicted Image No.1", img_1, epoch)
                    vdl_writer.add_image("Val/Target Image No.1", img_gt_1, epoch)

                    img_2 = to_image(out[2])
                    img_gt_2 = to_image(out_gt[2])
                    vdl_writer.add_image("Val/Predicted Image No.2", img_2, epoch)
                    vdl_writer.add_image("Val/Target Image No.2", img_gt_2, epoch)

            avg_ssim /= len(val_loader)
            avg_ms_ssim /= len(val_loader)
            avg_l1_loss /= len(val_loader)
            avg_mse_loss /= len(val_loader)

            if use_vdl:
                # log epoch averages rather than the last batch's values
                vdl_writer.add_scalar("Val/L1 Loss", float(avg_l1_loss), epoch)
                vdl_writer.add_scalar("Val/MSE Loss", float(avg_mse_loss), epoch)
                vdl_writer.add_scalar("Val/SSIM", float(avg_ssim), epoch)
                vdl_writer.add_scalar("Val/MS-SSIM", float(avg_ms_ssim), epoch)

        # Save
        ckpt = {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "best_fitness": best_fitness,
            "epoch": epoch,
        }

        paddle.save(ckpt, str(last))

        if best_fitness < avg_ssim:
            best_fitness = avg_ssim
            paddle.save(ckpt, str(best))

    if use_vdl:
        vdl_writer.close()


def main(args):
    print_args(vars(args))

    args.save_dir = str(
        increment_path(Path(args.project) / args.name, exist_ok=args.exist_ok)
    )

    train(args)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Hyperparams")
    parser.add_argument(
        "--data-root",
        nargs="?",
        type=str,
        default="~/datasets/doc3d",
        help="The root path of the dataset",
    )
    parser.add_argument(
        "--img-size",
        nargs="?",
        type=int,
        default=288,
        help="The size of the input image",
    )
    parser.add_argument(
        "--epochs",
        nargs="?",
        type=int,
        default=65,
        help="The number of training epochs",
    )
    parser.add_argument(
        "--batch-size", nargs="?", type=int, default=12, help="Batch Size"
    )
    parser.add_argument(
        "--lr", nargs="?", type=float, default=1e-04, help="Learning Rate"
    )
    parser.add_argument(
        "--resume",
        nargs="?",
        type=str,
        default=None,
        help="Path to previous saved model to restart from",
    )
    parser.add_argument("--workers", type=int, default=8, help="max dataloader workers")
    parser.add_argument(
        "--project", default=ROOT / "runs/train", help="save to project/name"
    )
    parser.add_argument("--name", default="exp", help="save to project/name")
    parser.add_argument(
        "--exist-ok",
        action="store_true",
        help="existing project/name ok, do not increment",
    )
    parser.add_argument("--seed", type=int, default=0, help="Global training seed")
    parser.add_argument("--use-vdl", action="store_true", help="use VisualDL as logger")
    args = parser.parse_args()
    main(args)
10 doc_dewarp/train.sh Normal file
@@ -0,0 +1,10 @@
export OPENCV_IO_ENABLE_OPENEXR=1
export FLAGS_logtostderr=0
export CUDA_VISIBLE_DEVICES=0

python train.py --img-size 288 \
    --name "DocTr++" \
    --batch-size 12 \
    --lr 1e-4 \
    --exist-ok \
    --use-vdl
67 doc_dewarp/utils.py Normal file
@@ -0,0 +1,67 @@
import numpy as np
import paddle
from paddle.nn import functional as F


def to_tensor(img: np.ndarray):
    """
    Converts a numpy array image (HWC) to a Paddle tensor (NCHW).

    Args:
        img (numpy.ndarray): The input image as a numpy array.

    Returns:
        out (paddle.Tensor): The output tensor.
    """
    img = img[:, :, ::-1]
    img = img.astype("float32") / 255.0
    img = img.transpose(2, 0, 1)
    out: paddle.Tensor = paddle.to_tensor(img)
    out = paddle.unsqueeze(out, axis=0)

    return out


def to_image(x: paddle.Tensor):
    """
    Converts a Paddle tensor (NCHW) to a numpy array image (HWC).

    Args:
        x (paddle.Tensor): The input tensor.

    Returns:
        out (numpy.ndarray): The output image as a numpy array.
    """
    out: np.ndarray = x.squeeze().numpy()
    out = out.transpose(1, 2, 0)
    out = out * 255.0
    out = out.astype("uint8")
    out = out[:, :, ::-1]

    return out


def unwarp(img, bm, bm_data_format="NCHW"):
    """
    Unwarp an image using a flow field.

    Args:
        img (paddle.Tensor): The input image.
        bm (paddle.Tensor): The flow field.

    Returns:
        out (paddle.Tensor): The output image.
    """
    _, _, h, w = img.shape

    if bm_data_format == "NHWC":
        bm = bm.transpose([0, 3, 1, 2])

    # NCHW
    bm = F.upsample(bm, size=(h, w), mode="bilinear", align_corners=True)
    # NHWC
    bm = bm.transpose([0, 2, 3, 1])
    # NCHW
    out = F.grid_sample(img, bm)

    return out
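A hedged round-trip sketch for the helpers above (illustrative only, not part of the committed files). Note that unwarp feeds the backward map straight into grid_sample, so its values must already be normalized to [-1, 1]; predict.py performs that normalization inline as (bm / 288 - 0.5) * 2. The file paths and the random stand-in map below are assumptions:

# Hedged sketch: to_tensor / unwarp / to_image round trip with a stand-in backward map.
import cv2
import paddle
from utils import to_image, to_tensor, unwarp

img = to_tensor(cv2.imread("warped.jpg"))        # hypothetical input path, HWC BGR -> NCHW float32
bm = paddle.rand([1, 2, 288, 288]) * 2.0 - 1.0   # stand-in for a model prediction, already in [-1, 1]
flat = unwarp(img, bm, bm_data_format="NCHW")
cv2.imwrite("flat.jpg", to_image(flat))          # hypothetical output path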
152 doc_dewarp/weight_init.py Normal file
@@ -0,0 +1,152 @@
import math

import numpy as np
import paddle
import paddle.nn.initializer as init
from scipy import special


def weight_init_(
    layer, func, weight_name=None, bias_name=None, bias_value=0.0, **kwargs
):
    """
    In-place params init function.
    Usage:
    .. code-block:: python

        import paddle
        import numpy as np

        data = np.ones([3, 4], dtype='float32')
        linear = paddle.nn.Linear(4, 4)
        input = paddle.to_tensor(data)
        print(linear.weight)
        linear(input)

        weight_init_(linear, 'Normal', 'fc_w0', 'fc_b0', std=0.01, mean=0.1)
        print(linear.weight)
    """

    if hasattr(layer, "weight") and layer.weight is not None:
        getattr(init, func)(**kwargs)(layer.weight)
        if weight_name is not None:
            # override weight name
            layer.weight.name = weight_name

    if hasattr(layer, "bias") and layer.bias is not None:
        init.Constant(bias_value)(layer.bias)
        if bias_name is not None:
            # override bias name
            layer.bias.name = bias_name


def _no_grad_trunc_normal_(tensor, mean, std, a, b):
    def norm_cdf(x):
        # Computes standard normal cumulative distribution function
        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

    if (mean < a - 2 * std) or (mean > b + 2 * std):
        print(
            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
            "The distribution of values may be incorrect."
        )

    with paddle.no_grad():
        # Values are generated by using a truncated uniform distribution and
        # then using the inverse CDF for the normal distribution.
        # Get upper and lower cdf values
        lower = norm_cdf((a - mean) / std)
        upper = norm_cdf((b - mean) / std)

        # Uniformly fill tensor with values from [l, u], then translate to [2l-1, 2u-1].
        tmp = np.random.uniform(
            2 * lower - 1, 2 * upper - 1, size=list(tensor.shape)
        ).astype(np.float32)

        # Use inverse cdf transform for normal distribution to get truncated
        # standard normal
        tmp = special.erfinv(tmp)

        # Transform to proper mean, std
        tmp *= std * math.sqrt(2.0)
        tmp += mean

        # Clamp to ensure it's in the proper range
        tmp = np.clip(tmp, a, b)
        tensor.set_value(paddle.to_tensor(tmp))

    return tensor


def _calculate_fan_in_and_fan_out(tensor):
    dimensions = tensor.dim()
    if dimensions < 2:
        raise ValueError(
            "Fan in and fan out can not be computed for tensor "
            "with fewer than 2 dimensions"
        )

    num_input_fmaps = tensor.shape[1]
    num_output_fmaps = tensor.shape[0]
    receptive_field_size = 1
    if tensor.dim() > 2:
        receptive_field_size = tensor[0][0].numel()
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size

    return fan_in, fan_out


def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
    return _no_grad_trunc_normal_(tensor, mean, std, a, b)


def kaiming_normal_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
    def _calculate_correct_fan(tensor, mode):
        mode = mode.lower()
        valid_modes = ["fan_in", "fan_out"]
        if mode not in valid_modes:
            raise ValueError(
                "Mode {} not supported, please use one of {}".format(mode, valid_modes)
            )

        fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
        return fan_in if mode == "fan_in" else fan_out

    def calculate_gain(nonlinearity, param=None):
        linear_fns = [
            "linear",
            "conv1d",
            "conv2d",
            "conv3d",
            "conv_transpose1d",
            "conv_transpose2d",
            "conv_transpose3d",
        ]
        if nonlinearity in linear_fns or nonlinearity == "sigmoid":
            return 1
        elif nonlinearity == "tanh":
            return 5.0 / 3
        elif nonlinearity == "relu":
            return math.sqrt(2.0)
        elif nonlinearity == "leaky_relu":
            if param is None:
                negative_slope = 0.01
            elif (
                not isinstance(param, bool)
                and isinstance(param, int)
                or isinstance(param, float)
            ):
                negative_slope = param
            else:
                raise ValueError("negative_slope {} not a valid number".format(param))
            return math.sqrt(2.0 / (1 + negative_slope**2))
        else:
            raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))

    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with paddle.no_grad():
        paddle.nn.initializer.Normal(0, std)(tensor)
    return tensor