diff --git a/model/.gitkeep b/dewarp/__init__.py similarity index 100% rename from model/.gitkeep rename to dewarp/__init__.py diff --git a/dewarp/dewarp.py b/dewarp/dewarp.py new file mode 100644 index 0000000..d8bce43 --- /dev/null +++ b/dewarp/dewarp.py @@ -0,0 +1,96 @@ +import os + +import cv2 +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.autograd import Variable + +from dewarp.models import get_model +from dewarp.utils import convert_state_dict + +DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + +def unwarp(img, bm): + w, h = img.shape[0], img.shape[1] + bm = bm.transpose(1, 2).transpose(2, 3).detach().cpu().numpy()[0, :, :, :] + bm0 = cv2.blur(bm[:, :, 0], (3, 3)) + bm1 = cv2.blur(bm[:, :, 1], (3, 3)) + bm0 = cv2.resize(bm0, (h, w)) + bm1 = cv2.resize(bm1, (h, w)) + bm = np.stack([bm0, bm1], axis=-1) + bm = np.expand_dims(bm, 0) + bm = torch.from_numpy(bm).double() + + img = img.astype(float) / 255.0 + img = img.transpose((2, 0, 1)) + img = np.expand_dims(img, 0) + img = torch.from_numpy(img).double() + + res = F.grid_sample(input=img, grid=bm, align_corners=True) + res = res[0].numpy().transpose((1, 2, 0)) + + return res + + +def dewarp_image(image): + wc_model_path = "model/dewarp_model/unetnc_doc3d.pkl" + bm_model_path = "model/dewarp_model/dnetccnl_doc3d.pkl" + wc_model_file_name = os.path.split(wc_model_path)[1] + wc_model_name = wc_model_file_name[:wc_model_file_name.find('_')] + + bm_model_file_name = os.path.split(bm_model_path)[1] + bm_model_name = bm_model_file_name[:bm_model_file_name.find('_')] + + wc_n_classes = 3 + bm_n_classes = 2 + + wc_img_size = (256, 256) + bm_img_size = (128, 128) + + # Setup image + imgorg = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + img = cv2.resize(imgorg, wc_img_size) + + img = img[:, :, ::-1] + img = img.astype(float) / 255.0 + img = img.transpose(2, 0, 1) # NHWC -> NCHW + img = np.expand_dims(img, 0) + img = torch.from_numpy(img).float() + + # Predict + htan = nn.Hardtanh(0, 1.0) + wc_model = get_model(wc_model_name, wc_n_classes, in_channels=3) + if DEVICE.type == 'cpu': + wc_state = convert_state_dict(torch.load(wc_model_path, map_location='cpu', weights_only=True)['model_state']) + else: + wc_state = convert_state_dict(torch.load(wc_model_path, weights_only=True)['model_state']) + wc_model.load_state_dict(wc_state) + wc_model.eval() + bm_model = get_model(bm_model_name, bm_n_classes, in_channels=3) + if DEVICE.type == 'cpu': + bm_state = convert_state_dict(torch.load(bm_model_path, map_location='cpu', weights_only=True)['model_state']) + else: + bm_state = convert_state_dict(torch.load(bm_model_path, weights_only=True)['model_state']) + bm_model.load_state_dict(bm_state) + bm_model.eval() + + if torch.cuda.is_available(): + wc_model.cuda() + bm_model.cuda() + images = Variable(img.cuda()) + else: + images = Variable(img) + + with torch.no_grad(): + wc_outputs = wc_model(images) + pred_wc = htan(wc_outputs) + bm_input = F.interpolate(pred_wc, bm_img_size) + outputs_bm = bm_model(bm_input) + + # call unwarp + uwpred = unwarp(imgorg, outputs_bm) + uwpred = (uwpred * 255).astype(np.uint8) + return cv2.cvtColor(uwpred, cv2.COLOR_RGB2BGR) diff --git a/dewarp/models/__init__.py b/dewarp/models/__init__.py new file mode 100644 index 0000000..50b9225 --- /dev/null +++ b/dewarp/models/__init__.py @@ -0,0 +1,24 @@ +from dewarp.models.densenetccnl import dnetccnl +from dewarp.models.unetnc import UnetGenerator + + +def get_model(name, n_classes=1, in_channels=3): + model = _get_model_instance(name) + + if name == 'dnetccnl': + model = model(img_size=128, in_channels=in_channels, out_channels=n_classes, filters=32) + elif name == 'unetnc': + model = model(input_nc=in_channels, output_nc=n_classes, num_downs=7) + else: + model = model(n_classes=n_classes) + return model + + +def _get_model_instance(name): + try: + return { + 'dnetccnl': dnetccnl, + 'unetnc': UnetGenerator, + }[name] + except: + print('Model {} not available'.format(name)) diff --git a/dewarp/models/densenetccnl.py b/dewarp/models/densenetccnl.py new file mode 100644 index 0000000..d92ed92 --- /dev/null +++ b/dewarp/models/densenetccnl.py @@ -0,0 +1,243 @@ +# Densenet decoder encoder with intermediate fully connected layers and dropout + +import numpy as np +import torch +import torch.nn as nn + + +def add_coordConv_channels(t): + n, c, h, w = t.size() + xx_channel = np.ones((h, w)) + xx_range = np.array(range(h)) + xx_range = np.expand_dims(xx_range, -1) + xx_coord = xx_channel * xx_range + yy_coord = xx_coord.transpose() + + xx_coord = xx_coord / (h - 1) + yy_coord = yy_coord / (h - 1) + xx_coord = xx_coord * 2 - 1 + yy_coord = yy_coord * 2 - 1 + xx_coord = torch.from_numpy(xx_coord).float() + yy_coord = torch.from_numpy(yy_coord).float() + + if t.is_cuda: + xx_coord = xx_coord.cuda() + yy_coord = yy_coord.cuda() + + xx_coord = xx_coord.unsqueeze(0).unsqueeze(0).repeat(n, 1, 1, 1) + yy_coord = yy_coord.unsqueeze(0).unsqueeze(0).repeat(n, 1, 1, 1) + + t_cc = torch.cat((t, xx_coord, yy_coord), dim=1) + + return t_cc + + +class DenseBlockEncoder(nn.Module): + def __init__(self, n_channels, n_convs, activation=nn.ReLU, args=[False]): + super(DenseBlockEncoder, self).__init__() + assert (n_convs > 0) + + self.n_channels = n_channels + self.n_convs = n_convs + self.layers = nn.ModuleList() + for i in range(n_convs): + self.layers.append(nn.Sequential( + nn.BatchNorm2d(n_channels), + activation(*args), + nn.Conv2d(n_channels, n_channels, 3, stride=1, padding=1, bias=False), )) + + def forward(self, inputs): + outputs = [] + + for i, layer in enumerate(self.layers): + if i > 0: + next_output = 0 + for no in outputs: + next_output = next_output + no + outputs.append(next_output) + else: + outputs.append(layer(inputs)) + return outputs[-1] + + +# Dense block in encoder. +class DenseBlockDecoder(nn.Module): + def __init__(self, n_channels, n_convs, activation=nn.ReLU, args=[False]): + super(DenseBlockDecoder, self).__init__() + assert (n_convs > 0) + + self.n_channels = n_channels + self.n_convs = n_convs + self.layers = nn.ModuleList() + for i in range(n_convs): + self.layers.append(nn.Sequential( + nn.BatchNorm2d(n_channels), + activation(*args), + nn.ConvTranspose2d(n_channels, n_channels, 3, stride=1, padding=1, bias=False), )) + + def forward(self, inputs): + outputs = [] + + for i, layer in enumerate(self.layers): + if i > 0: + next_output = 0 + for no in outputs: + next_output = next_output + no + outputs.append(next_output) + else: + outputs.append(layer(inputs)) + return outputs[-1] + + +class DenseTransitionBlockEncoder(nn.Module): + def __init__(self, n_channels_in, n_channels_out, mp, activation=nn.ReLU, args=[False]): + super(DenseTransitionBlockEncoder, self).__init__() + self.n_channels_in = n_channels_in + self.n_channels_out = n_channels_out + self.mp = mp + self.main = nn.Sequential( + nn.BatchNorm2d(n_channels_in), + activation(*args), + nn.Conv2d(n_channels_in, n_channels_out, 1, stride=1, padding=0, bias=False), + nn.MaxPool2d(mp), + ) + + def forward(self, inputs): + return self.main(inputs) + + +class DenseTransitionBlockDecoder(nn.Module): + def __init__(self, n_channels_in, n_channels_out, activation=nn.ReLU, args=[False]): + super(DenseTransitionBlockDecoder, self).__init__() + self.n_channels_in = n_channels_in + self.n_channels_out = n_channels_out + self.main = nn.Sequential( + nn.BatchNorm2d(n_channels_in), + activation(*args), + nn.ConvTranspose2d(n_channels_in, n_channels_out, 4, stride=2, padding=1, bias=False), + ) + + def forward(self, inputs): + return self.main(inputs) + + +## Dense encoders and decoders for image of size 128 128 +class waspDenseEncoder128(nn.Module): + def __init__(self, nc=1, ndf=32, ndim=128, activation=nn.LeakyReLU, args=[0.2, False], f_activation=nn.Tanh, + f_args=[]): + super(waspDenseEncoder128, self).__init__() + self.ndim = ndim + + self.main = nn.Sequential( + # input is (nc) x 128 x 128 + nn.BatchNorm2d(nc), + nn.ReLU(True), + nn.Conv2d(nc, ndf, 4, stride=2, padding=1), + + # state size. (ndf) x 64 x 64 + DenseBlockEncoder(ndf, 6), + DenseTransitionBlockEncoder(ndf, ndf * 2, 2, activation=activation, args=args), + + # state size. (ndf*2) x 32 x 32 + DenseBlockEncoder(ndf * 2, 12), + DenseTransitionBlockEncoder(ndf * 2, ndf * 4, 2, activation=activation, args=args), + + # state size. (ndf*4) x 16 x 16 + DenseBlockEncoder(ndf * 4, 16), + DenseTransitionBlockEncoder(ndf * 4, ndf * 8, 2, activation=activation, args=args), + + # state size. (ndf*4) x 8 x 8 + DenseBlockEncoder(ndf * 8, 16), + DenseTransitionBlockEncoder(ndf * 8, ndf * 8, 2, activation=activation, args=args), + + # state size. (ndf*8) x 4 x 4 + DenseBlockEncoder(ndf * 8, 16), + DenseTransitionBlockEncoder(ndf * 8, ndim, 4, activation=activation, args=args), + f_activation(*f_args), + ) + + def forward(self, input): + input = add_coordConv_channels(input) + output = self.main(input).view(-1, self.ndim) + # print(output.size()) + return output + + +class waspDenseDecoder128(nn.Module): + def __init__(self, nz=128, nc=1, ngf=32, lb=0, ub=1, activation=nn.ReLU, args=[False], f_activation=nn.Hardtanh, + f_args=[]): + super(waspDenseDecoder128, self).__init__() + self.main = nn.Sequential( + # input is Z, going into convolution + nn.BatchNorm2d(nz), + activation(*args), + nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False), + + # state size. (ngf*8) x 4 x 4 + DenseBlockDecoder(ngf * 8, 16), + DenseTransitionBlockDecoder(ngf * 8, ngf * 8), + + # state size. (ngf*4) x 8 x 8 + DenseBlockDecoder(ngf * 8, 16), + DenseTransitionBlockDecoder(ngf * 8, ngf * 4), + + # state size. (ngf*2) x 16 x 16 + DenseBlockDecoder(ngf * 4, 12), + DenseTransitionBlockDecoder(ngf * 4, ngf * 2), + + # state size. (ngf) x 32 x 32 + DenseBlockDecoder(ngf * 2, 6), + DenseTransitionBlockDecoder(ngf * 2, ngf), + + # state size. (ngf) x 64 x 64 + DenseBlockDecoder(ngf, 6), + DenseTransitionBlockDecoder(ngf, ngf), + + # state size (ngf) x 128 x 128 + nn.BatchNorm2d(ngf), + activation(*args), + nn.ConvTranspose2d(ngf, nc, 3, stride=1, padding=1, bias=False), + f_activation(*f_args), + ) + # self.smooth=nn.Sequential( + # nn.Conv2d(nc, nc, 1, stride=1, padding=0, bias=False), + # f_activation(*f_args), + # ) + + def forward(self, inputs): + # return self.smooth(self.main(inputs)) + return self.main(inputs) + + +class dnetccnl(nn.Module): + # in_channels -> nc | encoder first layer + # filters -> ndf | encoder first layer + # img_size(h,w) -> ndim + # out_channels -> optical flow (x,y) + + def __init__(self, img_size=128, in_channels=1, out_channels=2, filters=32, fc_units=100): + super(dnetccnl, self).__init__() + self.nc = in_channels + self.nf = filters + self.ndim = img_size + self.oc = out_channels + self.fcu = fc_units + + self.encoder = waspDenseEncoder128(nc=self.nc + 2, ndf=self.nf, ndim=self.ndim) + self.decoder = waspDenseDecoder128(nz=self.ndim, nc=self.oc, ngf=self.nf) + # self.fc_layers= nn.Sequential(nn.Linear(self.ndim, self.fcu), + # nn.ReLU(True), + # nn.Dropout(0.25), + # nn.Linear(self.fcu,self.ndim), + # nn.ReLU(True), + # nn.Dropout(0.25), + # ) + + def forward(self, inputs): + encoded = self.encoder(inputs) + encoded = encoded.unsqueeze(-1).unsqueeze(-1) + decoded = self.decoder(encoded) + # print torch.max(decoded) + # print torch.min(decoded) + + return decoded diff --git a/dewarp/models/unetnc.py b/dewarp/models/unetnc.py new file mode 100644 index 0000000..b6c38dc --- /dev/null +++ b/dewarp/models/unetnc.py @@ -0,0 +1,89 @@ +import functools + +import torch +import torch.nn as nn + + +# Defines the Unet generator. +# |num_downs|: number of downsamplings in UNet. For example, +# if |num_downs| == 7, image of size 128x128 will become of size 1x1 +# at the bottleneck +class UnetGenerator(nn.Module): + def __init__(self, input_nc, output_nc, num_downs, ngf=64, + norm_layer=nn.BatchNorm2d, use_dropout=False): + super(UnetGenerator, self).__init__() + + # construct unet structure + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, + innermost=True) + for i in range(num_downs - 5): + unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, + norm_layer=norm_layer, use_dropout=use_dropout) + unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, + norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, + norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer) + unet_block = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, + norm_layer=norm_layer) + + self.model = unet_block + + def forward(self, input): + return self.model(input) + + +# Defines the submodule with skip connection. +# X -------------------identity---------------------- X +# |-- downsampling -- |submodule| -- upsampling --| +class UnetSkipConnectionBlock(nn.Module): + def __init__(self, outer_nc, inner_nc, input_nc=None, + submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False): + super(UnetSkipConnectionBlock, self).__init__() + self.outermost = outermost + if type(norm_layer) == functools.partial: + use_bias = norm_layer.func == nn.InstanceNorm2d + else: + use_bias = norm_layer == nn.InstanceNorm2d + if input_nc is None: + input_nc = outer_nc + downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4, + stride=2, padding=1, bias=use_bias) + downrelu = nn.LeakyReLU(0.2, True) + downnorm = norm_layer(inner_nc) + uprelu = nn.ReLU(True) + upnorm = norm_layer(outer_nc) + + if outermost: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1) + down = [downconv] + up = [uprelu, upconv, nn.Tanh()] + model = down + [submodule] + up + elif innermost: + upconv = nn.ConvTranspose2d(inner_nc, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv] + up = [uprelu, upconv, upnorm] + model = down + up + else: + upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc, + kernel_size=4, stride=2, + padding=1, bias=use_bias) + down = [downrelu, downconv, downnorm] + up = [uprelu, upconv, upnorm] + + if use_dropout: + model = down + [submodule] + up + [nn.Dropout(0.5)] + else: + model = down + [submodule] + up + + self.model = nn.Sequential(*model) + + def forward(self, x): + if self.outermost: + return self.model(x) + else: + return torch.cat([x, self.model(x)], 1) diff --git a/dewarp/utils.py b/dewarp/utils.py new file mode 100644 index 0000000..65c0c4b --- /dev/null +++ b/dewarp/utils.py @@ -0,0 +1,17 @@ +''' +Misc Utility functions +''' +from collections import OrderedDict + + +def convert_state_dict(state_dict): + """Converts a state dict saved from a dataParallel module to normal + module state_dict inplace + :param state_dict is the loaded DataParallel model_state + + """ + new_state_dict = OrderedDict() + for k, v in state_dict.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + return new_state_dict