Commit: Document detection (文档检测)
object_detection/core/models/CornerNet.py (new file, 73 lines)
@@ -0,0 +1,73 @@
import torch
import torch.nn as nn

from .py_utils import TopPool, BottomPool, LeftPool, RightPool
from .py_utils.losses import CornerNet_Loss
from .py_utils.modules import hg_module, hg, hg_net
from .py_utils.utils import convolution, residual, corner_pool


def make_pool_layer(dim):
    return nn.Sequential()


def make_hg_layer(inp_dim, out_dim, modules):
    layers = [residual(inp_dim, out_dim, stride=2)]
    layers += [residual(out_dim, out_dim) for _ in range(1, modules)]
    return nn.Sequential(*layers)


class model(hg_net):
    def _pred_mod(self, dim):
        return nn.Sequential(
            convolution(3, 256, 256, with_bn=False),
            nn.Conv2d(256, dim, (1, 1))
        )

    def _merge_mod(self):
        return nn.Sequential(
            nn.Conv2d(256, 256, (1, 1), bias=False),
            nn.BatchNorm2d(256)
        )

    def __init__(self):
        stacks = 2
        pre = nn.Sequential(
            convolution(7, 3, 128, stride=2),
            residual(128, 256, stride=2)
        )
        hg_mods = nn.ModuleList([
            hg_module(
                5, [256, 256, 384, 384, 384, 512], [2, 2, 2, 2, 2, 4],
                make_pool_layer=make_pool_layer,
                make_hg_layer=make_hg_layer
            ) for _ in range(stacks)
        ])
        cnvs = nn.ModuleList([convolution(3, 256, 256) for _ in range(stacks)])
        inters = nn.ModuleList([residual(256, 256) for _ in range(stacks - 1)])
        cnvs_ = nn.ModuleList([self._merge_mod() for _ in range(stacks - 1)])
        inters_ = nn.ModuleList([self._merge_mod() for _ in range(stacks - 1)])

        hgs = hg(pre, hg_mods, cnvs, inters, cnvs_, inters_)

        tl_modules = nn.ModuleList([corner_pool(256, TopPool, LeftPool) for _ in range(stacks)])
        br_modules = nn.ModuleList([corner_pool(256, BottomPool, RightPool) for _ in range(stacks)])

        tl_heats = nn.ModuleList([self._pred_mod(80) for _ in range(stacks)])
        br_heats = nn.ModuleList([self._pred_mod(80) for _ in range(stacks)])
        for tl_heat, br_heat in zip(tl_heats, br_heats):
            torch.nn.init.constant_(tl_heat[-1].bias, -2.19)
            torch.nn.init.constant_(br_heat[-1].bias, -2.19)

        tl_tags = nn.ModuleList([self._pred_mod(1) for _ in range(stacks)])
        br_tags = nn.ModuleList([self._pred_mod(1) for _ in range(stacks)])

        tl_offs = nn.ModuleList([self._pred_mod(2) for _ in range(stacks)])
        br_offs = nn.ModuleList([self._pred_mod(2) for _ in range(stacks)])

        super(model, self).__init__(
            hgs, tl_modules, br_modules, tl_heats, br_heats,
            tl_tags, br_tags, tl_offs, br_offs
        )

        self.loss = CornerNet_Loss(pull_weight=1e-1, push_weight=1e-1)
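For orientation, the model above is used like any hg_net: the training-mode forward returns per-stack corner heatmaps, embeddings and offsets, and self.loss later consumes them together with the ground-truth targets. A minimal forward-pass sketch, assuming the _cpools extensions below have been compiled and that the package is importable as core.models (both assumptions about your layout; CornerNet is typically trained on 511x511 crops):

import torch
from core.models.CornerNet import model   # assumed import path

net = model().eval()
images = torch.randn(1, 3, 511, 511)
with torch.no_grad():
    outs = net(images)                     # test=False -> training-style outputs
tl_heats, br_heats, tl_tags, br_tags, tl_offs, br_offs = outs
print(tl_heats[0].shape)                   # torch.Size([1, 80, 128, 128]) for a 511x511 input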
object_detection/core/models/CornerNet_Saccade.py (new file, 93 lines)
@@ -0,0 +1,93 @@
import torch
import torch.nn as nn

from .py_utils import TopPool, BottomPool, LeftPool, RightPool
from .py_utils.losses import CornerNet_Saccade_Loss
from .py_utils.modules import saccade_net, saccade_module, saccade
from .py_utils.utils import convolution, residual, corner_pool


def make_pool_layer(dim):
    return nn.Sequential()


def make_hg_layer(inp_dim, out_dim, modules):
    layers = [residual(inp_dim, out_dim, stride=2)]
    layers += [residual(out_dim, out_dim) for _ in range(1, modules)]
    return nn.Sequential(*layers)


class model(saccade_net):
    def _pred_mod(self, dim):
        return nn.Sequential(
            convolution(3, 256, 256, with_bn=False),
            nn.Conv2d(256, dim, (1, 1))
        )

    def _merge_mod(self):
        return nn.Sequential(
            nn.Conv2d(256, 256, (1, 1), bias=False),
            nn.BatchNorm2d(256)
        )

    def __init__(self):
        stacks = 3
        pre = nn.Sequential(
            convolution(7, 3, 128, stride=2),
            residual(128, 256, stride=2)
        )
        hg_mods = nn.ModuleList([
            saccade_module(
                3, [256, 384, 384, 512], [1, 1, 1, 1],
                make_pool_layer=make_pool_layer,
                make_hg_layer=make_hg_layer
            ) for _ in range(stacks)
        ])
        cnvs = nn.ModuleList([convolution(3, 256, 256) for _ in range(stacks)])
        inters = nn.ModuleList([residual(256, 256) for _ in range(stacks - 1)])
        cnvs_ = nn.ModuleList([self._merge_mod() for _ in range(stacks - 1)])
        inters_ = nn.ModuleList([self._merge_mod() for _ in range(stacks - 1)])

        att_mods = nn.ModuleList([
            nn.ModuleList([
                nn.Sequential(
                    convolution(3, 384, 256, with_bn=False),
                    nn.Conv2d(256, 1, (1, 1))
                ),
                nn.Sequential(
                    convolution(3, 384, 256, with_bn=False),
                    nn.Conv2d(256, 1, (1, 1))
                ),
                nn.Sequential(
                    convolution(3, 256, 256, with_bn=False),
                    nn.Conv2d(256, 1, (1, 1))
                )
            ]) for _ in range(stacks)
        ])
        for att_mod in att_mods:
            for att in att_mod:
                torch.nn.init.constant_(att[-1].bias, -2.19)

        hgs = saccade(pre, hg_mods, cnvs, inters, cnvs_, inters_)

        tl_modules = nn.ModuleList([corner_pool(256, TopPool, LeftPool) for _ in range(stacks)])
        br_modules = nn.ModuleList([corner_pool(256, BottomPool, RightPool) for _ in range(stacks)])

        tl_heats = nn.ModuleList([self._pred_mod(80) for _ in range(stacks)])
        br_heats = nn.ModuleList([self._pred_mod(80) for _ in range(stacks)])
        for tl_heat, br_heat in zip(tl_heats, br_heats):
            torch.nn.init.constant_(tl_heat[-1].bias, -2.19)
            torch.nn.init.constant_(br_heat[-1].bias, -2.19)

        tl_tags = nn.ModuleList([self._pred_mod(1) for _ in range(stacks)])
        br_tags = nn.ModuleList([self._pred_mod(1) for _ in range(stacks)])

        tl_offs = nn.ModuleList([self._pred_mod(2) for _ in range(stacks)])
        br_offs = nn.ModuleList([self._pred_mod(2) for _ in range(stacks)])

        super(model, self).__init__(
            hgs, tl_modules, br_modules, tl_heats, br_heats,
            tl_tags, br_tags, tl_offs, br_offs, att_mods
        )

        self.loss = CornerNet_Saccade_Loss(pull_weight=1e-1, push_weight=1e-1)
object_detection/core/models/CornerNet_Squeeze.py (new file, 117 lines)
@@ -0,0 +1,117 @@
import torch
import torch.nn as nn

from .py_utils import TopPool, BottomPool, LeftPool, RightPool
from .py_utils.losses import CornerNet_Loss
from .py_utils.modules import hg_module, hg, hg_net
from .py_utils.utils import convolution, corner_pool, residual


class fire_module(nn.Module):
    def __init__(self, inp_dim, out_dim, sr=2, stride=1):
        super(fire_module, self).__init__()
        self.conv1 = nn.Conv2d(inp_dim, out_dim // sr, kernel_size=1, stride=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_dim // sr)
        self.conv_1x1 = nn.Conv2d(out_dim // sr, out_dim // 2, kernel_size=1, stride=stride, bias=False)
        self.conv_3x3 = nn.Conv2d(out_dim // sr, out_dim // 2, kernel_size=3, padding=1,
                                  stride=stride, groups=out_dim // sr, bias=False)
        self.bn2 = nn.BatchNorm2d(out_dim)
        self.skip = (stride == 1 and inp_dim == out_dim)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        conv1 = self.conv1(x)
        bn1 = self.bn1(conv1)
        conv2 = torch.cat((self.conv_1x1(bn1), self.conv_3x3(bn1)), 1)
        bn2 = self.bn2(conv2)
        if self.skip:
            return self.relu(bn2 + x)
        else:
            return self.relu(bn2)


def make_pool_layer(dim):
    return nn.Sequential()


def make_unpool_layer(dim):
    return nn.ConvTranspose2d(dim, dim, kernel_size=4, stride=2, padding=1)


def make_layer(inp_dim, out_dim, modules):
    layers = [fire_module(inp_dim, out_dim)]
    layers += [fire_module(out_dim, out_dim) for _ in range(1, modules)]
    return nn.Sequential(*layers)


def make_layer_revr(inp_dim, out_dim, modules):
    layers = [fire_module(inp_dim, inp_dim) for _ in range(modules - 1)]
    layers += [fire_module(inp_dim, out_dim)]
    return nn.Sequential(*layers)


def make_hg_layer(inp_dim, out_dim, modules):
    layers = [fire_module(inp_dim, out_dim, stride=2)]
    layers += [fire_module(out_dim, out_dim) for _ in range(1, modules)]
    return nn.Sequential(*layers)


class model(hg_net):
    def _pred_mod(self, dim):
        return nn.Sequential(
            convolution(1, 256, 256, with_bn=False),
            nn.Conv2d(256, dim, (1, 1))
        )

    def _merge_mod(self):
        return nn.Sequential(
            nn.Conv2d(256, 256, (1, 1), bias=False),
            nn.BatchNorm2d(256)
        )

    def __init__(self):
        stacks = 2
        pre = nn.Sequential(
            convolution(7, 3, 128, stride=2),
            residual(128, 256, stride=2),
            residual(256, 256, stride=2)
        )
        hg_mods = nn.ModuleList([
            hg_module(
                4, [256, 256, 384, 384, 512], [2, 2, 2, 2, 4],
                make_pool_layer=make_pool_layer,
                make_unpool_layer=make_unpool_layer,
                make_up_layer=make_layer,
                make_low_layer=make_layer,
                make_hg_layer_revr=make_layer_revr,
                make_hg_layer=make_hg_layer
            ) for _ in range(stacks)
        ])
        cnvs = nn.ModuleList([convolution(3, 256, 256) for _ in range(stacks)])
        inters = nn.ModuleList([residual(256, 256) for _ in range(stacks - 1)])
        cnvs_ = nn.ModuleList([self._merge_mod() for _ in range(stacks - 1)])
        inters_ = nn.ModuleList([self._merge_mod() for _ in range(stacks - 1)])

        hgs = hg(pre, hg_mods, cnvs, inters, cnvs_, inters_)

        tl_modules = nn.ModuleList([corner_pool(256, TopPool, LeftPool) for _ in range(stacks)])
        br_modules = nn.ModuleList([corner_pool(256, BottomPool, RightPool) for _ in range(stacks)])

        tl_heats = nn.ModuleList([self._pred_mod(80) for _ in range(stacks)])
        br_heats = nn.ModuleList([self._pred_mod(80) for _ in range(stacks)])
        for tl_heat, br_heat in zip(tl_heats, br_heats):
            torch.nn.init.constant_(tl_heat[-1].bias, -2.19)
            torch.nn.init.constant_(br_heat[-1].bias, -2.19)

        tl_tags = nn.ModuleList([self._pred_mod(1) for _ in range(stacks)])
        br_tags = nn.ModuleList([self._pred_mod(1) for _ in range(stacks)])

        tl_offs = nn.ModuleList([self._pred_mod(2) for _ in range(stacks)])
        br_offs = nn.ModuleList([self._pred_mod(2) for _ in range(stacks)])

        super(model, self).__init__(
            hgs, tl_modules, br_modules, tl_heats, br_heats,
            tl_tags, br_tags, tl_offs, br_offs
        )

        self.loss = CornerNet_Loss(pull_weight=1e-1, push_weight=1e-1)
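The fire_module above follows the SqueezeNet idea: a 1x1 squeeze down to out_dim // sr channels, then parallel 1x1 and depthwise 3x3 expands whose outputs are concatenated back to out_dim, with an identity skip only when stride is 1 and the channel count is preserved. A small shape check, assuming the file is importable as core.models.CornerNet_Squeeze and that the _cpools extensions are built (the module imports them at the top):

import torch
from core.models.CornerNet_Squeeze import fire_module   # assumed import path

m = fire_module(256, 256)              # stride=1 and inp_dim == out_dim -> residual skip is used
x = torch.randn(2, 256, 32, 32)
print(m(x).shape)                      # torch.Size([2, 256, 32, 32])

m2 = fire_module(256, 512, stride=2)   # downsampling variant used by make_hg_layer; no skip
print(m2(x).shape)                     # torch.Size([2, 512, 16, 16])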
object_detection/core/models/__init__.py (new file, empty)
object_detection/core/models/py_utils/__init__.py (new file, 1 line)
@@ -0,0 +1 @@
from ._cpools import TopPool, BottomPool, LeftPool, RightPool
object_detection/core/models/py_utils/_cpools/.gitignore (new file, vendored, 3 lines)
@@ -0,0 +1,3 @@
build/
cpools.egg-info/
dist/
object_detection/core/models/py_utils/_cpools/__init__.py (new file, 82 lines)
@@ -0,0 +1,82 @@
import bottom_pool
import left_pool
import right_pool
import top_pool
from torch import nn
from torch.autograd import Function


class TopPoolFunction(Function):
    @staticmethod
    def forward(ctx, input):
        output = top_pool.forward(input)[0]
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_variables[0]
        output = top_pool.backward(input, grad_output)[0]
        return output


class BottomPoolFunction(Function):
    @staticmethod
    def forward(ctx, input):
        output = bottom_pool.forward(input)[0]
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_variables[0]
        output = bottom_pool.backward(input, grad_output)[0]
        return output


class LeftPoolFunction(Function):
    @staticmethod
    def forward(ctx, input):
        output = left_pool.forward(input)[0]
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_variables[0]
        output = left_pool.backward(input, grad_output)[0]
        return output


class RightPoolFunction(Function):
    @staticmethod
    def forward(ctx, input):
        output = right_pool.forward(input)[0]
        ctx.save_for_backward(input)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input = ctx.saved_variables[0]
        output = right_pool.backward(input, grad_output)[0]
        return output


class TopPool(nn.Module):
    def forward(self, x):
        return TopPoolFunction.apply(x)


class BottomPool(nn.Module):
    def forward(self, x):
        return BottomPoolFunction.apply(x)


class LeftPool(nn.Module):
    def forward(self, x):
        return LeftPoolFunction.apply(x)


class RightPool(nn.Module):
    def forward(self, x):
        return RightPoolFunction.apply(x)
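Functionally, each of the four compiled ops is a directional running max over one spatial axis: TopPool propagates maxima upward (each output pixel is the max of everything at or below it in its column), BottomPool downward, LeftPool leftward along rows, RightPool rightward. A pure-PyTorch reference for the same forward computation, useful for sanity checks or CPU-only experiments (a sketch, not part of this commit):

import torch

def top_pool_ref(x):     # max over rows i..H-1, per column
    return torch.flip(torch.cummax(torch.flip(x, [2]), dim=2)[0], [2])

def bottom_pool_ref(x):  # max over rows 0..i
    return torch.cummax(x, dim=2)[0]

def left_pool_ref(x):    # max over columns j..W-1, per row
    return torch.flip(torch.cummax(torch.flip(x, [3]), dim=3)[0], [3])

def right_pool_ref(x):   # max over columns 0..j
    return torch.cummax(x, dim=3)[0]

x = torch.randn(1, 2, 4, 5)
# once the extension is built, the results should match, e.g.:
# assert torch.allclose(top_pool_ref(x), top_pool.forward(x)[0])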
object_detection/core/models/py_utils/_cpools/setup.py (new file, 15 lines)
@@ -0,0 +1,15 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CppExtension

setup(
    name="cpools",
    ext_modules=[
        CppExtension("top_pool", ["src/top_pool.cpp"]),
        CppExtension("bottom_pool", ["src/bottom_pool.cpp"]),
        CppExtension("left_pool", ["src/left_pool.cpp"]),
        CppExtension("right_pool", ["src/right_pool.cpp"])
    ],
    cmdclass={
        "build_ext": BuildExtension
    }
)
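The four extensions are built by running this setup script from the _cpools directory (for example with setup.py install or build_ext --inplace), after which the plain import top_pool, bottom_pool, left_pool, right_pool statements in __init__.py resolve. During development, torch's JIT builder can compile a single op on the fly instead; a sketch, with the source path assumed relative to _cpools:

from torch.utils.cpp_extension import load

# compiles src/top_pool.cpp on first call and caches the shared library
top_pool = load(name="top_pool", sources=["src/top_pool.cpp"])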
object_detection/core/models/py_utils/_cpools/src/bottom_pool.cpp (new file, 80 lines)
@@ -0,0 +1,80 @@
#include <torch/torch.h>

#include <vector>

std::vector<at::Tensor> pool_forward(
    at::Tensor input
) {
    // Initialize output
    at::Tensor output = at::zeros_like(input);

    // Get height
    int64_t height = input.size(2);

    output.copy_(input);

    for (int64_t ind = 1; ind < height; ind <<= 1) {
        at::Tensor max_temp  = at::slice(output, 2, ind, height);
        at::Tensor cur_temp  = at::slice(output, 2, ind, height);
        at::Tensor next_temp = at::slice(output, 2, 0, height - ind);
        at::max_out(max_temp, cur_temp, next_temp);
    }

    return {
        output
    };
}

std::vector<at::Tensor> pool_backward(
    at::Tensor input,
    at::Tensor grad_output
) {
    auto output = at::zeros_like(input);

    int32_t batch   = input.size(0);
    int32_t channel = input.size(1);
    int32_t height  = input.size(2);
    int32_t width   = input.size(3);

    auto max_val = torch::zeros({batch, channel, width}, at::device(at::kCUDA).dtype(at::kFloat));
    auto max_ind = torch::zeros({batch, channel, width}, at::device(at::kCUDA).dtype(at::kLong));

    auto input_temp = input.select(2, 0);
    max_val.copy_(input_temp);

    max_ind.fill_(0);

    auto output_temp      = output.select(2, 0);
    auto grad_output_temp = grad_output.select(2, 0);
    output_temp.copy_(grad_output_temp);

    auto un_max_ind = max_ind.unsqueeze(2);
    auto gt_mask    = torch::zeros({batch, channel, width}, at::device(at::kCUDA).dtype(at::kByte));
    auto max_temp   = torch::zeros({batch, channel, width}, at::device(at::kCUDA).dtype(at::kFloat));
    for (int32_t ind = 0; ind < height - 1; ++ind) {
        input_temp = input.select(2, ind + 1);
        at::gt_out(gt_mask, input_temp, max_val);

        at::masked_select_out(max_temp, input_temp, gt_mask);
        max_val.masked_scatter_(gt_mask, max_temp);
        max_ind.masked_fill_(gt_mask, ind + 1);

        grad_output_temp = grad_output.select(2, ind + 1).unsqueeze(2);
        output.scatter_add_(2, un_max_ind, grad_output_temp);
    }

    return {
        output
    };
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward", &pool_forward, "Bottom Pool Forward",
        py::call_guard<py::gil_scoped_release>()
    );
    m.def(
        "backward", &pool_backward, "Bottom Pool Backward",
        py::call_guard<py::gil_scoped_release>()
    );
}
object_detection/core/models/py_utils/_cpools/src/left_pool.cpp (new file, 80 lines)
@@ -0,0 +1,80 @@
#include <torch/torch.h>

#include <vector>

std::vector<at::Tensor> pool_forward(
    at::Tensor input
) {
    // Initialize output
    at::Tensor output = at::zeros_like(input);

    // Get width
    int64_t width = input.size(3);

    output.copy_(input);

    for (int64_t ind = 1; ind < width; ind <<= 1) {
        at::Tensor max_temp  = at::slice(output, 3, 0, width - ind);
        at::Tensor cur_temp  = at::slice(output, 3, 0, width - ind);
        at::Tensor next_temp = at::slice(output, 3, ind, width);
        at::max_out(max_temp, cur_temp, next_temp);
    }

    return {
        output
    };
}

std::vector<at::Tensor> pool_backward(
    at::Tensor input,
    at::Tensor grad_output
) {
    auto output = at::zeros_like(input);

    int32_t batch   = input.size(0);
    int32_t channel = input.size(1);
    int32_t height  = input.size(2);
    int32_t width   = input.size(3);

    auto max_val = torch::zeros({batch, channel, height}, at::device(at::kCUDA).dtype(at::kFloat));
    auto max_ind = torch::zeros({batch, channel, height}, at::device(at::kCUDA).dtype(at::kLong));

    auto input_temp = input.select(3, width - 1);
    max_val.copy_(input_temp);

    max_ind.fill_(width - 1);

    auto output_temp      = output.select(3, width - 1);
    auto grad_output_temp = grad_output.select(3, width - 1);
    output_temp.copy_(grad_output_temp);

    auto un_max_ind = max_ind.unsqueeze(3);
    auto gt_mask    = torch::zeros({batch, channel, height}, at::device(at::kCUDA).dtype(at::kByte));
    auto max_temp   = torch::zeros({batch, channel, height}, at::device(at::kCUDA).dtype(at::kFloat));
    for (int32_t ind = 1; ind < width; ++ind) {
        input_temp = input.select(3, width - ind - 1);
        at::gt_out(gt_mask, input_temp, max_val);

        at::masked_select_out(max_temp, input_temp, gt_mask);
        max_val.masked_scatter_(gt_mask, max_temp);
        max_ind.masked_fill_(gt_mask, width - ind - 1);

        grad_output_temp = grad_output.select(3, width - ind - 1).unsqueeze(3);
        output.scatter_add_(3, un_max_ind, grad_output_temp);
    }

    return {
        output
    };
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward", &pool_forward, "Left Pool Forward",
        py::call_guard<py::gil_scoped_release>()
    );
    m.def(
        "backward", &pool_backward, "Left Pool Backward",
        py::call_guard<py::gil_scoped_release>()
    );
}
object_detection/core/models/py_utils/_cpools/src/right_pool.cpp (new file, 80 lines)
@@ -0,0 +1,80 @@
#include <torch/torch.h>

#include <vector>

std::vector<at::Tensor> pool_forward(
    at::Tensor input
) {
    // Initialize output
    at::Tensor output = at::zeros_like(input);

    // Get width
    int64_t width = input.size(3);

    output.copy_(input);

    for (int64_t ind = 1; ind < width; ind <<= 1) {
        at::Tensor max_temp  = at::slice(output, 3, ind, width);
        at::Tensor cur_temp  = at::slice(output, 3, ind, width);
        at::Tensor next_temp = at::slice(output, 3, 0, width - ind);
        at::max_out(max_temp, cur_temp, next_temp);
    }

    return {
        output
    };
}

std::vector<at::Tensor> pool_backward(
    at::Tensor input,
    at::Tensor grad_output
) {
    at::Tensor output = at::zeros_like(input);

    int32_t batch   = input.size(0);
    int32_t channel = input.size(1);
    int32_t height  = input.size(2);
    int32_t width   = input.size(3);

    auto max_val = torch::zeros({batch, channel, height}, at::device(at::kCUDA).dtype(at::kFloat));
    auto max_ind = torch::zeros({batch, channel, height}, at::device(at::kCUDA).dtype(at::kLong));

    auto input_temp = input.select(3, 0);
    max_val.copy_(input_temp);

    max_ind.fill_(0);

    auto output_temp      = output.select(3, 0);
    auto grad_output_temp = grad_output.select(3, 0);
    output_temp.copy_(grad_output_temp);

    auto un_max_ind = max_ind.unsqueeze(3);
    auto gt_mask    = torch::zeros({batch, channel, height}, at::device(at::kCUDA).dtype(at::kByte));
    auto max_temp   = torch::zeros({batch, channel, height}, at::device(at::kCUDA).dtype(at::kFloat));
    for (int32_t ind = 0; ind < width - 1; ++ind) {
        input_temp = input.select(3, ind + 1);
        at::gt_out(gt_mask, input_temp, max_val);

        at::masked_select_out(max_temp, input_temp, gt_mask);
        max_val.masked_scatter_(gt_mask, max_temp);
        max_ind.masked_fill_(gt_mask, ind + 1);

        grad_output_temp = grad_output.select(3, ind + 1).unsqueeze(3);
        output.scatter_add_(3, un_max_ind, grad_output_temp);
    }

    return {
        output
    };
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward", &pool_forward, "Right Pool Forward",
        py::call_guard<py::gil_scoped_release>()
    );
    m.def(
        "backward", &pool_backward, "Right Pool Backward",
        py::call_guard<py::gil_scoped_release>()
    );
}
object_detection/core/models/py_utils/_cpools/src/top_pool.cpp (new file, 80 lines)
@@ -0,0 +1,80 @@
#include <torch/torch.h>

#include <vector>

std::vector<at::Tensor> top_pool_forward(
    at::Tensor input
) {
    // Initialize output
    at::Tensor output = at::zeros_like(input);

    // Get height
    int64_t height = input.size(2);

    output.copy_(input);

    for (int64_t ind = 1; ind < height; ind <<= 1) {
        at::Tensor max_temp  = at::slice(output, 2, 0, height - ind);
        at::Tensor cur_temp  = at::slice(output, 2, 0, height - ind);
        at::Tensor next_temp = at::slice(output, 2, ind, height);
        at::max_out(max_temp, cur_temp, next_temp);
    }

    return {
        output
    };
}

std::vector<at::Tensor> top_pool_backward(
    at::Tensor input,
    at::Tensor grad_output
) {
    auto output = at::zeros_like(input);

    int32_t batch   = input.size(0);
    int32_t channel = input.size(1);
    int32_t height  = input.size(2);
    int32_t width   = input.size(3);

    auto max_val = torch::zeros({batch, channel, width}, at::device(at::kCUDA).dtype(at::kFloat));
    auto max_ind = torch::zeros({batch, channel, width}, at::device(at::kCUDA).dtype(at::kLong));

    auto input_temp = input.select(2, height - 1);
    max_val.copy_(input_temp);

    max_ind.fill_(height - 1);

    auto output_temp      = output.select(2, height - 1);
    auto grad_output_temp = grad_output.select(2, height - 1);
    output_temp.copy_(grad_output_temp);

    auto un_max_ind = max_ind.unsqueeze(2);
    auto gt_mask    = torch::zeros({batch, channel, width}, at::device(at::kCUDA).dtype(at::kByte));
    auto max_temp   = torch::zeros({batch, channel, width}, at::device(at::kCUDA).dtype(at::kFloat));
    for (int32_t ind = 1; ind < height; ++ind) {
        input_temp = input.select(2, height - ind - 1);
        at::gt_out(gt_mask, input_temp, max_val);

        at::masked_select_out(max_temp, input_temp, gt_mask);
        max_val.masked_scatter_(gt_mask, max_temp);
        max_ind.masked_fill_(gt_mask, height - ind - 1);

        grad_output_temp = grad_output.select(2, height - ind - 1).unsqueeze(2);
        output.scatter_add_(2, un_max_ind, grad_output_temp);
    }

    return {
        output
    };
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def(
        "forward", &top_pool_forward, "Top Pool Forward",
        py::call_guard<py::gil_scoped_release>()
    );
    m.def(
        "backward", &top_pool_backward, "Top Pool Backward",
        py::call_guard<py::gil_scoped_release>()
    );
}
object_detection/core/models/py_utils/data_parallel.py (new file, 117 lines)
@@ -0,0 +1,117 @@
import torch
from torch.nn.modules import Module
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.scatter_gather import gather

from .scatter_gather import scatter_kwargs


class DataParallel(Module):
    r"""Implements data parallelism at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices by chunking in the batch
    dimension. In the forward pass, the module is replicated on each device,
    and each replica handles a portion of the input. During the backwards
    pass, gradients from each replica are summed into the original module.

    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs so that each chunk is the
    same size (so that each GPU processes the same number of samples).

    See also: :ref:`cuda-nn-dataparallel-instead`

    Arbitrary positional and keyword inputs are allowed to be passed into
    DataParallel EXCEPT Tensors. All variables will be scattered on the dim
    specified (default 0). Primitive types will be broadcasted, but all
    other types will be a shallow copy and can be corrupted if written to in
    the model's forward pass.

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)
        output_device: device location of output (default: device_ids[0])

    Example::

        >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
        >>> output = net(input_var)
    """

    # TODO: update notes/cuda.rst when this class handles 8+ GPUs well

    def __init__(self, module, device_ids=None, output_device=None, dim=0, chunk_sizes=None):
        super(DataParallel, self).__init__()

        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]
        self.dim = dim
        self.module = module
        self.device_ids = device_ids
        self.chunk_sizes = chunk_sizes
        self.output_device = output_device
        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids, self.chunk_sizes)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def replicate(self, module, device_ids):
        return replicate(module, device_ids)

    def scatter(self, inputs, kwargs, device_ids, chunk_sizes):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim, chunk_sizes=self.chunk_sizes)

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def gather(self, outputs, output_device):
        return gather(outputs, output_device, dim=self.dim)


def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
    r"""Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module: the module to evaluate in parallel
        inputs: inputs to the module
        device_ids: GPU ids on which to replicate module
        output_device: GPU location of the output. Use -1 to indicate the CPU.
            (default: device_ids[0])
    Returns:
        a Variable containing the result of module(input) located on
        output_device
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs,)

    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))

    if output_device is None:
        output_device = device_ids[0]

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
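The only difference from torch.nn.DataParallel is the extra chunk_sizes argument, which is forwarded to scatter_kwargs so the batch can be split unevenly across GPUs (useful when device 0 also carries extra memory overhead). A usage sketch, assuming two visible GPUs and a toy module (both assumptions for illustration):

import torch
import torch.nn as nn
from core.models.py_utils.data_parallel import DataParallel   # assumed import path

net = nn.Conv2d(3, 8, 3, padding=1).cuda()
net = DataParallel(net, device_ids=[0, 1], chunk_sizes=[4, 8])   # uneven split of a batch of 12
out = net(torch.randn(12, 3, 64, 64).cuda())
print(out.shape)   # torch.Size([12, 8, 64, 64]), gathered back on device 0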
object_detection/core/models/py_utils/losses.py (new file, 231 lines)
@@ -0,0 +1,231 @@
import torch
import torch.nn as nn

from .utils import _tranpose_and_gather_feat


def _sigmoid(x):
    return torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4)


def _ae_loss(tag0, tag1, mask):
    num = mask.sum(dim=1, keepdim=True).float()
    tag0 = tag0.squeeze()
    tag1 = tag1.squeeze()

    tag_mean = (tag0 + tag1) / 2

    tag0 = torch.pow(tag0 - tag_mean, 2) / (num + 1e-4)
    tag0 = tag0[mask].sum()
    tag1 = torch.pow(tag1 - tag_mean, 2) / (num + 1e-4)
    tag1 = tag1[mask].sum()
    pull = tag0 + tag1

    mask = mask.unsqueeze(1) + mask.unsqueeze(2)
    mask = mask.eq(2)
    num = num.unsqueeze(2)
    num2 = (num - 1) * num
    dist = tag_mean.unsqueeze(1) - tag_mean.unsqueeze(2)
    dist = 1 - torch.abs(dist)
    dist = nn.functional.relu(dist, inplace=True)
    dist = dist - 1 / (num + 1e-4)
    dist = dist / (num2 + 1e-4)
    dist = dist[mask]
    push = dist.sum()
    return pull, push


def _off_loss(off, gt_off, mask):
    num = mask.float().sum()
    mask = mask.unsqueeze(2).expand_as(gt_off)

    off = off[mask]
    gt_off = gt_off[mask]

    off_loss = nn.functional.smooth_l1_loss(off, gt_off, reduction="sum")
    off_loss = off_loss / (num + 1e-4)
    return off_loss


def _focal_loss_mask(preds, gt, mask):
    pos_inds = gt.eq(1)
    neg_inds = gt.lt(1)

    neg_weights = torch.pow(1 - gt[neg_inds], 4)

    pos_mask = mask[pos_inds]
    neg_mask = mask[neg_inds]

    loss = 0
    for pred in preds:
        pos_pred = pred[pos_inds]
        neg_pred = pred[neg_inds]

        pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2) * pos_mask
        neg_loss = torch.log(1 - neg_pred) * torch.pow(neg_pred, 2) * neg_weights * neg_mask

        num_pos = pos_inds.float().sum()
        pos_loss = pos_loss.sum()
        neg_loss = neg_loss.sum()

        if pos_pred.nelement() == 0:
            loss = loss - neg_loss
        else:
            loss = loss - (pos_loss + neg_loss) / num_pos
    return loss


def _focal_loss(preds, gt):
    pos_inds = gt.eq(1)
    neg_inds = gt.lt(1)

    neg_weights = torch.pow(1 - gt[neg_inds], 4)

    loss = 0
    for pred in preds:
        pos_pred = pred[pos_inds]
        neg_pred = pred[neg_inds]

        pos_loss = torch.log(pos_pred) * torch.pow(1 - pos_pred, 2)
        neg_loss = torch.log(1 - neg_pred) * torch.pow(neg_pred, 2) * neg_weights

        num_pos = pos_inds.float().sum()
        pos_loss = pos_loss.sum()
        neg_loss = neg_loss.sum()

        if pos_pred.nelement() == 0:
            loss = loss - neg_loss
        else:
            loss = loss - (pos_loss + neg_loss) / num_pos
    return loss


class CornerNet_Saccade_Loss(nn.Module):
    def __init__(self, pull_weight=1, push_weight=1, off_weight=1, focal_loss=_focal_loss_mask):
        super(CornerNet_Saccade_Loss, self).__init__()

        self.pull_weight = pull_weight
        self.push_weight = push_weight
        self.off_weight = off_weight
        self.focal_loss = focal_loss
        self.ae_loss = _ae_loss
        self.off_loss = _off_loss

    def forward(self, outs, targets):
        tl_heats = outs[0]
        br_heats = outs[1]
        tl_tags = outs[2]
        br_tags = outs[3]
        tl_offs = outs[4]
        br_offs = outs[5]
        atts = outs[6]

        gt_tl_heat = targets[0]
        gt_br_heat = targets[1]
        gt_mask = targets[2]
        gt_tl_off = targets[3]
        gt_br_off = targets[4]
        gt_tl_ind = targets[5]
        gt_br_ind = targets[6]
        gt_tl_valid = targets[7]
        gt_br_valid = targets[8]
        gt_atts = targets[9]

        # focal loss
        focal_loss = 0

        tl_heats = [_sigmoid(t) for t in tl_heats]
        br_heats = [_sigmoid(b) for b in br_heats]

        focal_loss += self.focal_loss(tl_heats, gt_tl_heat, gt_tl_valid)
        focal_loss += self.focal_loss(br_heats, gt_br_heat, gt_br_valid)

        atts = [[_sigmoid(a) for a in att] for att in atts]
        atts = [[att[ind] for att in atts] for ind in range(len(gt_atts))]

        att_loss = 0
        for att, gt_att in zip(atts, gt_atts):
            att_loss += _focal_loss(att, gt_att) / max(len(att), 1)

        # tag loss
        pull_loss = 0
        push_loss = 0
        tl_tags = [_tranpose_and_gather_feat(tl_tag, gt_tl_ind) for tl_tag in tl_tags]
        br_tags = [_tranpose_and_gather_feat(br_tag, gt_br_ind) for br_tag in br_tags]
        for tl_tag, br_tag in zip(tl_tags, br_tags):
            pull, push = self.ae_loss(tl_tag, br_tag, gt_mask)
            pull_loss += pull
            push_loss += push
        pull_loss = self.pull_weight * pull_loss
        push_loss = self.push_weight * push_loss

        off_loss = 0
        tl_offs = [_tranpose_and_gather_feat(tl_off, gt_tl_ind) for tl_off in tl_offs]
        br_offs = [_tranpose_and_gather_feat(br_off, gt_br_ind) for br_off in br_offs]
        for tl_off, br_off in zip(tl_offs, br_offs):
            off_loss += self.off_loss(tl_off, gt_tl_off, gt_mask)
            off_loss += self.off_loss(br_off, gt_br_off, gt_mask)
        off_loss = self.off_weight * off_loss

        loss = (focal_loss + att_loss + pull_loss + push_loss + off_loss) / max(len(tl_heats), 1)
        return loss.unsqueeze(0)


class CornerNet_Loss(nn.Module):
    def __init__(self, pull_weight=1, push_weight=1, off_weight=1, focal_loss=_focal_loss):
        super(CornerNet_Loss, self).__init__()

        self.pull_weight = pull_weight
        self.push_weight = push_weight
        self.off_weight = off_weight
        self.focal_loss = focal_loss
        self.ae_loss = _ae_loss
        self.off_loss = _off_loss

    def forward(self, outs, targets):
        tl_heats = outs[0]
        br_heats = outs[1]
        tl_tags = outs[2]
        br_tags = outs[3]
        tl_offs = outs[4]
        br_offs = outs[5]

        gt_tl_heat = targets[0]
        gt_br_heat = targets[1]
        gt_mask = targets[2]
        gt_tl_off = targets[3]
        gt_br_off = targets[4]
        gt_tl_ind = targets[5]
        gt_br_ind = targets[6]

        # focal loss
        focal_loss = 0

        tl_heats = [_sigmoid(t) for t in tl_heats]
        br_heats = [_sigmoid(b) for b in br_heats]

        focal_loss += self.focal_loss(tl_heats, gt_tl_heat)
        focal_loss += self.focal_loss(br_heats, gt_br_heat)

        # tag loss
        pull_loss = 0
        push_loss = 0
        tl_tags = [_tranpose_and_gather_feat(tl_tag, gt_tl_ind) for tl_tag in tl_tags]
        br_tags = [_tranpose_and_gather_feat(br_tag, gt_br_ind) for br_tag in br_tags]
        for tl_tag, br_tag in zip(tl_tags, br_tags):
            pull, push = self.ae_loss(tl_tag, br_tag, gt_mask)
            pull_loss += pull
            push_loss += push
        pull_loss = self.pull_weight * pull_loss
        push_loss = self.push_weight * push_loss

        off_loss = 0
        tl_offs = [_tranpose_and_gather_feat(tl_off, gt_tl_ind) for tl_off in tl_offs]
        br_offs = [_tranpose_and_gather_feat(br_off, gt_br_ind) for br_off in br_offs]
        for tl_off, br_off in zip(tl_offs, br_offs):
            off_loss += self.off_loss(tl_off, gt_tl_off, gt_mask)
            off_loss += self.off_loss(br_off, gt_br_off, gt_mask)
        off_loss = self.off_weight * off_loss

        loss = (focal_loss + pull_loss + push_loss + off_loss) / max(len(tl_heats), 1)
        return loss.unsqueeze(0)
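For reference, _focal_loss and _ae_loss follow the CornerNet training objective (the code adds small epsilons, and the Saccade variant additionally masks out invalid pixels, but the structure is the same): the heatmap term is a variant of focal loss with a Gaussian penalty-reduction factor on negatives, and the embedding term pulls the two corners of one object toward their mean while pushing different objects apart. In the notation of the paper, with N positive corners per image:

L_{det} = -\frac{1}{N}\sum_{c,i,j}
\begin{cases}
(1 - p_{cij})^{2}\,\log p_{cij} & \text{if } y_{cij} = 1\\
(1 - y_{cij})^{4}\,p_{cij}^{2}\,\log(1 - p_{cij}) & \text{otherwise}
\end{cases}

L_{pull} = \frac{1}{N}\sum_{k}\Bigl[(e_{t_k} - e_k)^2 + (e_{b_k} - e_k)^2\Bigr],
\qquad e_k = \tfrac{1}{2}\,(e_{t_k} + e_{b_k})

L_{push} = \frac{1}{N(N-1)}\sum_{k}\sum_{j \neq k}\max\bigl(0,\ 1 - |e_k - e_j|\bigr)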
object_detection/core/models/py_utils/modules.py (new file, 303 lines)
@@ -0,0 +1,303 @@
import torch
import torch.nn as nn

from .utils import residual, upsample, merge, _decode


def _make_layer(inp_dim, out_dim, modules):
    layers = [residual(inp_dim, out_dim)]
    layers += [residual(out_dim, out_dim) for _ in range(1, modules)]
    return nn.Sequential(*layers)


def _make_layer_revr(inp_dim, out_dim, modules):
    layers = [residual(inp_dim, inp_dim) for _ in range(modules - 1)]
    layers += [residual(inp_dim, out_dim)]
    return nn.Sequential(*layers)


def _make_pool_layer(dim):
    return nn.MaxPool2d(kernel_size=2, stride=2)


def _make_unpool_layer(dim):
    return upsample(scale_factor=2)


def _make_merge_layer(dim):
    return merge()


class hg_module(nn.Module):
    def __init__(
        self, n, dims, modules, make_up_layer=_make_layer,
        make_pool_layer=_make_pool_layer, make_hg_layer=_make_layer,
        make_low_layer=_make_layer, make_hg_layer_revr=_make_layer_revr,
        make_unpool_layer=_make_unpool_layer, make_merge_layer=_make_merge_layer
    ):
        super(hg_module, self).__init__()

        curr_mod = modules[0]
        next_mod = modules[1]

        curr_dim = dims[0]
        next_dim = dims[1]

        self.n = n
        self.up1 = make_up_layer(curr_dim, curr_dim, curr_mod)
        self.max1 = make_pool_layer(curr_dim)
        self.low1 = make_hg_layer(curr_dim, next_dim, curr_mod)
        self.low2 = hg_module(
            n - 1, dims[1:], modules[1:],
            make_up_layer=make_up_layer,
            make_pool_layer=make_pool_layer,
            make_hg_layer=make_hg_layer,
            make_low_layer=make_low_layer,
            make_hg_layer_revr=make_hg_layer_revr,
            make_unpool_layer=make_unpool_layer,
            make_merge_layer=make_merge_layer
        ) if n > 1 else make_low_layer(next_dim, next_dim, next_mod)
        self.low3 = make_hg_layer_revr(next_dim, curr_dim, curr_mod)
        self.up2 = make_unpool_layer(curr_dim)
        self.merg = make_merge_layer(curr_dim)

    def forward(self, x):
        up1 = self.up1(x)
        max1 = self.max1(x)
        low1 = self.low1(max1)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2 = self.up2(low3)
        merg = self.merg(up1, up2)
        return merg


class hg(nn.Module):
    def __init__(self, pre, hg_modules, cnvs, inters, cnvs_, inters_):
        super(hg, self).__init__()

        self.pre = pre
        self.hgs = hg_modules
        self.cnvs = cnvs

        self.inters = inters
        self.inters_ = inters_
        self.cnvs_ = cnvs_

    def forward(self, x):
        inter = self.pre(x)

        cnvs = []
        for ind, (hg_, cnv_) in enumerate(zip(self.hgs, self.cnvs)):
            hg = hg_(inter)
            cnv = cnv_(hg)
            cnvs.append(cnv)

            if ind < len(self.hgs) - 1:
                inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv)
                inter = nn.functional.relu_(inter)
                inter = self.inters[ind](inter)
        return cnvs


class hg_net(nn.Module):
    def __init__(
        self, hg, tl_modules, br_modules, tl_heats, br_heats,
        tl_tags, br_tags, tl_offs, br_offs
    ):
        super(hg_net, self).__init__()

        self._decode = _decode

        self.hg = hg

        self.tl_modules = tl_modules
        self.br_modules = br_modules

        self.tl_heats = tl_heats
        self.br_heats = br_heats

        self.tl_tags = tl_tags
        self.br_tags = br_tags

        self.tl_offs = tl_offs
        self.br_offs = br_offs

    def _train(self, *xs):
        image = xs[0]
        cnvs = self.hg(image)

        tl_modules = [tl_mod_(cnv) for tl_mod_, cnv in zip(self.tl_modules, cnvs)]
        br_modules = [br_mod_(cnv) for br_mod_, cnv in zip(self.br_modules, cnvs)]
        tl_heats = [tl_heat_(tl_mod) for tl_heat_, tl_mod in zip(self.tl_heats, tl_modules)]
        br_heats = [br_heat_(br_mod) for br_heat_, br_mod in zip(self.br_heats, br_modules)]
        tl_tags = [tl_tag_(tl_mod) for tl_tag_, tl_mod in zip(self.tl_tags, tl_modules)]
        br_tags = [br_tag_(br_mod) for br_tag_, br_mod in zip(self.br_tags, br_modules)]
        tl_offs = [tl_off_(tl_mod) for tl_off_, tl_mod in zip(self.tl_offs, tl_modules)]
        br_offs = [br_off_(br_mod) for br_off_, br_mod in zip(self.br_offs, br_modules)]
        return [tl_heats, br_heats, tl_tags, br_tags, tl_offs, br_offs]

    def _test(self, *xs, **kwargs):
        image = xs[0]
        cnvs = self.hg(image)

        tl_mod = self.tl_modules[-1](cnvs[-1])
        br_mod = self.br_modules[-1](cnvs[-1])

        tl_heat, br_heat = self.tl_heats[-1](tl_mod), self.br_heats[-1](br_mod)
        tl_tag, br_tag = self.tl_tags[-1](tl_mod), self.br_tags[-1](br_mod)
        tl_off, br_off = self.tl_offs[-1](tl_mod), self.br_offs[-1](br_mod)

        outs = [tl_heat, br_heat, tl_tag, br_tag, tl_off, br_off]
        return self._decode(*outs, **kwargs), tl_heat, br_heat, tl_tag, br_tag

    def forward(self, *xs, test=False, **kwargs):
        if not test:
            return self._train(*xs, **kwargs)
        return self._test(*xs, **kwargs)


class saccade_module(nn.Module):
    def __init__(
        self, n, dims, modules, make_up_layer=_make_layer,
        make_pool_layer=_make_pool_layer, make_hg_layer=_make_layer,
        make_low_layer=_make_layer, make_hg_layer_revr=_make_layer_revr,
        make_unpool_layer=_make_unpool_layer, make_merge_layer=_make_merge_layer
    ):
        super(saccade_module, self).__init__()

        curr_mod = modules[0]
        next_mod = modules[1]

        curr_dim = dims[0]
        next_dim = dims[1]

        self.n = n
        self.up1 = make_up_layer(curr_dim, curr_dim, curr_mod)
        self.max1 = make_pool_layer(curr_dim)
        self.low1 = make_hg_layer(curr_dim, next_dim, curr_mod)
        self.low2 = saccade_module(
            n - 1, dims[1:], modules[1:],
            make_up_layer=make_up_layer,
            make_pool_layer=make_pool_layer,
            make_hg_layer=make_hg_layer,
            make_low_layer=make_low_layer,
            make_hg_layer_revr=make_hg_layer_revr,
            make_unpool_layer=make_unpool_layer,
            make_merge_layer=make_merge_layer
        ) if n > 1 else make_low_layer(next_dim, next_dim, next_mod)
        self.low3 = make_hg_layer_revr(next_dim, curr_dim, curr_mod)
        self.up2 = make_unpool_layer(curr_dim)
        self.merg = make_merge_layer(curr_dim)

    def forward(self, x):
        up1 = self.up1(x)
        max1 = self.max1(x)
        low1 = self.low1(max1)
        if self.n > 1:
            low2, mergs = self.low2(low1)
        else:
            low2, mergs = self.low2(low1), []
        low3 = self.low3(low2)
        up2 = self.up2(low3)
        merg = self.merg(up1, up2)
        mergs.append(merg)
        return merg, mergs


class saccade(nn.Module):
    def __init__(self, pre, hg_modules, cnvs, inters, cnvs_, inters_):
        super(saccade, self).__init__()

        self.pre = pre
        self.hgs = hg_modules
        self.cnvs = cnvs

        self.inters = inters
        self.inters_ = inters_
        self.cnvs_ = cnvs_

    def forward(self, x):
        inter = self.pre(x)

        cnvs = []
        atts = []
        for ind, (hg_, cnv_) in enumerate(zip(self.hgs, self.cnvs)):
            hg, ups = hg_(inter)
            cnv = cnv_(hg)
            cnvs.append(cnv)
            atts.append(ups)

            if ind < len(self.hgs) - 1:
                inter = self.inters_[ind](inter) + self.cnvs_[ind](cnv)
                inter = nn.functional.relu_(inter)
                inter = self.inters[ind](inter)
        return cnvs, atts


class saccade_net(nn.Module):
    def __init__(
        self, hg, tl_modules, br_modules, tl_heats, br_heats,
        tl_tags, br_tags, tl_offs, br_offs, att_modules, up_start=0
    ):
        super(saccade_net, self).__init__()

        self._decode = _decode

        self.hg = hg

        self.tl_modules = tl_modules
        self.br_modules = br_modules
        self.tl_heats = tl_heats
        self.br_heats = br_heats
        self.tl_tags = tl_tags
        self.br_tags = br_tags
        self.tl_offs = tl_offs
        self.br_offs = br_offs

        self.att_modules = att_modules
        self.up_start = up_start

    def _train(self, *xs):
        image = xs[0]

        cnvs, ups = self.hg(image)
        ups = [up[self.up_start:] for up in ups]

        tl_modules = [tl_mod_(cnv) for tl_mod_, cnv in zip(self.tl_modules, cnvs)]
        br_modules = [br_mod_(cnv) for br_mod_, cnv in zip(self.br_modules, cnvs)]
        tl_heats = [tl_heat_(tl_mod) for tl_heat_, tl_mod in zip(self.tl_heats, tl_modules)]
        br_heats = [br_heat_(br_mod) for br_heat_, br_mod in zip(self.br_heats, br_modules)]
        tl_tags = [tl_tag_(tl_mod) for tl_tag_, tl_mod in zip(self.tl_tags, tl_modules)]
        br_tags = [br_tag_(br_mod) for br_tag_, br_mod in zip(self.br_tags, br_modules)]
        tl_offs = [tl_off_(tl_mod) for tl_off_, tl_mod in zip(self.tl_offs, tl_modules)]
        br_offs = [br_off_(br_mod) for br_off_, br_mod in zip(self.br_offs, br_modules)]
        atts = [[att_mod_(u) for att_mod_, u in zip(att_mods, up)] for att_mods, up in zip(self.att_modules, ups)]
        return [tl_heats, br_heats, tl_tags, br_tags, tl_offs, br_offs, atts]

    def _test(self, *xs, no_att=False, **kwargs):
        image = xs[0]
        cnvs, ups = self.hg(image)
        ups = [up[self.up_start:] for up in ups]

        if not no_att:
            atts = [att_mod_(up) for att_mod_, up in zip(self.att_modules[-1], ups[-1])]
            atts = [torch.sigmoid(att) for att in atts]

        tl_mod = self.tl_modules[-1](cnvs[-1])
        br_mod = self.br_modules[-1](cnvs[-1])

        tl_heat, br_heat = self.tl_heats[-1](tl_mod), self.br_heats[-1](br_mod)
        tl_tag, br_tag = self.tl_tags[-1](tl_mod), self.br_tags[-1](br_mod)
        tl_off, br_off = self.tl_offs[-1](tl_mod), self.br_offs[-1](br_mod)

        outs = [tl_heat, br_heat, tl_tag, br_tag, tl_off, br_off]
        if not no_att:
            return self._decode(*outs, **kwargs), atts
        else:
            return self._decode(*outs, **kwargs)

    def forward(self, *xs, test=False, **kwargs):
        if not test:
            return self._train(*xs, **kwargs)
        return self._test(*xs, **kwargs)
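hg_module recurses n levels deep: at each level up1 keeps the current resolution while low1 (via make_hg_layer) works at half resolution, and up2 interpolates back before the merge, so with the identity pool layer used by the CornerNet configs the whole module maps (B, dims[0], H, W) to the same shape. A quick shape check with the CornerNet-style layer factories (a sketch assuming the package imports cleanly as core.models.py_utils, whose __init__ pulls in the compiled pooling ops):

import torch
import torch.nn as nn
from core.models.py_utils.modules import hg_module   # assumed import path
from core.models.py_utils.utils import residual

def make_pool_layer(dim):
    return nn.Sequential()                            # identity, as in CornerNet.py

def make_hg_layer(inp_dim, out_dim, modules):
    layers = [residual(inp_dim, out_dim, stride=2)]   # downsample once per level
    layers += [residual(out_dim, out_dim) for _ in range(1, modules)]
    return nn.Sequential(*layers)

hg = hg_module(5, [256, 256, 384, 384, 384, 512], [2, 2, 2, 2, 2, 4],
               make_pool_layer=make_pool_layer, make_hg_layer=make_hg_layer)
x = torch.randn(1, 256, 64, 64)                       # spatial size divisible by 2**5
print(hg(x).shape)                                    # torch.Size([1, 256, 64, 64])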
object_detection/core/models/py_utils/scatter_gather.py (new file, 39 lines)
@@ -0,0 +1,39 @@
import torch
from torch.autograd import Variable
from torch.nn.parallel._functions import Scatter


def scatter(inputs, target_gpus, dim=0, chunk_sizes=None):
    r"""
    Slices variables into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not variables. Does not
    support Tensors.
    """

    def scatter_map(obj):
        if isinstance(obj, Variable):
            return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
        assert not torch.is_tensor(obj), "Tensors not supported in scatter."
        if isinstance(obj, tuple):
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list):
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict):
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for targets in target_gpus]

    return scatter_map(inputs)


def scatter_kwargs(inputs, kwargs, target_gpus, dim=0, chunk_sizes=None):
    r"""Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, dim, chunk_sizes) if inputs else []
    kwargs = scatter(kwargs, target_gpus, dim, chunk_sizes) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
object_detection/core/models/py_utils/utils.py (new file, 236 lines)
@@ -0,0 +1,236 @@
import torch
import torch.nn as nn


def _gather_feat(feat, ind, mask=None):
    dim = feat.size(2)
    ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim)
    feat = feat.gather(1, ind)
    if mask is not None:
        mask = mask.unsqueeze(2).expand_as(feat)
        feat = feat[mask]
        feat = feat.view(-1, dim)
    return feat


def _nms(heat, kernel=1):
    pad = (kernel - 1) // 2

    hmax = nn.functional.max_pool2d(heat, (kernel, kernel), stride=1, padding=pad)
    keep = (hmax == heat).float()
    return heat * keep


def _tranpose_and_gather_feat(feat, ind):
    feat = feat.permute(0, 2, 3, 1).contiguous()
    feat = feat.view(feat.size(0), -1, feat.size(3))
    feat = _gather_feat(feat, ind)
    return feat


def _topk(scores, K=20):
    batch, cat, height, width = scores.size()

    topk_scores, topk_inds = torch.topk(scores.view(batch, -1), K)

    topk_clses = (topk_inds / (height * width)).int()

    topk_inds = topk_inds % (height * width)
    topk_ys = (topk_inds / width).int().float()
    topk_xs = (topk_inds % width).int().float()
    return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs


def _decode(
    tl_heat, br_heat, tl_tag, br_tag, tl_regr, br_regr,
    K=100, kernel=1, ae_threshold=1, num_dets=1000, no_border=False
):
    batch, cat, height, width = tl_heat.size()

    tl_heat = torch.sigmoid(tl_heat)
    br_heat = torch.sigmoid(br_heat)

    # perform nms on heatmaps
    tl_heat = _nms(tl_heat, kernel=kernel)
    br_heat = _nms(br_heat, kernel=kernel)

    tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = _topk(tl_heat, K=K)
    br_scores, br_inds, br_clses, br_ys, br_xs = _topk(br_heat, K=K)

    tl_ys = tl_ys.view(batch, K, 1).expand(batch, K, K)
    tl_xs = tl_xs.view(batch, K, 1).expand(batch, K, K)
    br_ys = br_ys.view(batch, 1, K).expand(batch, K, K)
    br_xs = br_xs.view(batch, 1, K).expand(batch, K, K)

    if no_border:
        tl_ys_binds = (tl_ys == 0)
        tl_xs_binds = (tl_xs == 0)
        br_ys_binds = (br_ys == height - 1)
        br_xs_binds = (br_xs == width - 1)

    if tl_regr is not None and br_regr is not None:
        tl_regr = _tranpose_and_gather_feat(tl_regr, tl_inds)
        tl_regr = tl_regr.view(batch, K, 1, 2)
        br_regr = _tranpose_and_gather_feat(br_regr, br_inds)
        br_regr = br_regr.view(batch, 1, K, 2)

        tl_xs = tl_xs + tl_regr[..., 0]
        tl_ys = tl_ys + tl_regr[..., 1]
        br_xs = br_xs + br_regr[..., 0]
        br_ys = br_ys + br_regr[..., 1]

    # all possible boxes based on top k corners (ignoring class)
    bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3)

    tl_tag = _tranpose_and_gather_feat(tl_tag, tl_inds)
    tl_tag = tl_tag.view(batch, K, 1)
    br_tag = _tranpose_and_gather_feat(br_tag, br_inds)
    br_tag = br_tag.view(batch, 1, K)
    dists = torch.abs(tl_tag - br_tag)

    tl_scores = tl_scores.view(batch, K, 1).expand(batch, K, K)
    br_scores = br_scores.view(batch, 1, K).expand(batch, K, K)
    scores = (tl_scores + br_scores) / 2

    # reject boxes based on classes
    tl_clses = tl_clses.view(batch, K, 1).expand(batch, K, K)
    br_clses = br_clses.view(batch, 1, K).expand(batch, K, K)
    cls_inds = (tl_clses != br_clses)

    # reject boxes based on distances
    dist_inds = (dists > ae_threshold)

    # reject boxes based on widths and heights
    width_inds = (br_xs < tl_xs)
    height_inds = (br_ys < tl_ys)

    if no_border:
        scores[tl_ys_binds] = -1
        scores[tl_xs_binds] = -1
        scores[br_ys_binds] = -1
        scores[br_xs_binds] = -1

    scores[cls_inds] = -1
    scores[dist_inds] = -1
    scores[width_inds] = -1
    scores[height_inds] = -1

    scores = scores.view(batch, -1)
    scores, inds = torch.topk(scores, num_dets)
    scores = scores.unsqueeze(2)

    bboxes = bboxes.view(batch, -1, 4)
    bboxes = _gather_feat(bboxes, inds)

    clses = tl_clses.contiguous().view(batch, -1, 1)
    clses = _gather_feat(clses, inds).float()

    tl_scores = tl_scores.contiguous().view(batch, -1, 1)
    tl_scores = _gather_feat(tl_scores, inds).float()
    br_scores = br_scores.contiguous().view(batch, -1, 1)
    br_scores = _gather_feat(br_scores, inds).float()

    detections = torch.cat([bboxes, scores, tl_scores, br_scores, clses], dim=2)
    return detections


class upsample(nn.Module):
    def __init__(self, scale_factor):
        super(upsample, self).__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        return nn.functional.interpolate(x, scale_factor=self.scale_factor)


class merge(nn.Module):
    def forward(self, x, y):
        return x + y


class convolution(nn.Module):
    def __init__(self, k, inp_dim, out_dim, stride=1, with_bn=True):
        super(convolution, self).__init__()

        pad = (k - 1) // 2
        self.conv = nn.Conv2d(inp_dim, out_dim, (k, k), padding=(pad, pad), stride=(stride, stride), bias=not with_bn)
        self.bn = nn.BatchNorm2d(out_dim) if with_bn else nn.Sequential()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        conv = self.conv(x)
        bn = self.bn(conv)
        relu = self.relu(bn)
        return relu


class residual(nn.Module):
    def __init__(self, inp_dim, out_dim, k=3, stride=1):
        super(residual, self).__init__()
        p = (k - 1) // 2

        self.conv1 = nn.Conv2d(inp_dim, out_dim, (k, k), padding=(p, p), stride=(stride, stride), bias=False)
        self.bn1 = nn.BatchNorm2d(out_dim)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_dim, out_dim, (k, k), padding=(p, p), bias=False)
        self.bn2 = nn.BatchNorm2d(out_dim)

        self.skip = nn.Sequential(
            nn.Conv2d(inp_dim, out_dim, (1, 1), stride=(stride, stride), bias=False),
            nn.BatchNorm2d(out_dim)
        ) if stride != 1 or inp_dim != out_dim else nn.Sequential()
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        conv1 = self.conv1(x)
        bn1 = self.bn1(conv1)
        relu1 = self.relu1(bn1)

        conv2 = self.conv2(relu1)
        bn2 = self.bn2(conv2)

        skip = self.skip(x)
        return self.relu(bn2 + skip)


class corner_pool(nn.Module):
    def __init__(self, dim, pool1, pool2):
        super(corner_pool, self).__init__()
        self._init_layers(dim, pool1, pool2)

    def _init_layers(self, dim, pool1, pool2):
        self.p1_conv1 = convolution(3, dim, 128)
        self.p2_conv1 = convolution(3, dim, 128)

        self.p_conv1 = nn.Conv2d(128, dim, (3, 3), padding=(1, 1), bias=False)
        self.p_bn1 = nn.BatchNorm2d(dim)

        self.conv1 = nn.Conv2d(dim, dim, (1, 1), bias=False)
        self.bn1 = nn.BatchNorm2d(dim)
        self.relu1 = nn.ReLU(inplace=True)

        self.conv2 = convolution(3, dim, dim)

        self.pool1 = pool1()
        self.pool2 = pool2()

    def forward(self, x):
        # pool 1
        p1_conv1 = self.p1_conv1(x)
        pool1 = self.pool1(p1_conv1)

        # pool 2
        p2_conv1 = self.p2_conv1(x)
        pool2 = self.pool2(p2_conv1)

        # pool 1 + pool 2
        p_conv1 = self.p_conv1(pool1 + pool2)
        p_bn1 = self.p_bn1(p_conv1)

        conv1 = self.conv1(x)
        bn1 = self.bn1(conv1)
        relu1 = self.relu1(p_bn1 + bn1)

        conv2 = self.conv2(relu1)
        return conv2
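The two gather helpers are the glue between dense head outputs and sparse indices (ground-truth corner positions during training, top-k peaks during decoding): _tranpose_and_gather_feat flattens a (B, C, H, W) map to (B, H*W, C), and _gather_feat then picks one C-vector per index. A tiny worked example of the same operations (a sketch, not part of the commit):

import torch

feat = torch.arange(2 * 3 * 2 * 2, dtype=torch.float32).view(2, 3, 2, 2)  # (B=2, C=3, H=2, W=2)
ind = torch.tensor([[0, 3], [1, 2]])                                      # (B, K=2) flat positions in H*W

flat = feat.permute(0, 2, 3, 1).contiguous().view(2, -1, 3)               # (B, H*W, C)
picked = flat.gather(1, ind.unsqueeze(2).expand(2, 2, 3))                 # (B, K, C), same as _gather_feat
print(picked[0, 0])   # channel values at spatial position 0 of image 0 -> tensor([0., 4., 8.])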