DocTr去扭曲

This commit is contained in:
2024-08-21 16:03:39 +08:00
parent 0e17f2b9aa
commit add29504e2
26 changed files with 2212 additions and 469 deletions

152
doc_dewarp/weight_init.py Normal file
View File

@@ -0,0 +1,152 @@
import math
import numpy as np
import paddle
import paddle.nn.initializer as init
from scipy import special
def weight_init_(
layer, func, weight_name=None, bias_name=None, bias_value=0.0, **kwargs
):
"""
In-place params init function.
Usage:
.. code-block:: python
import paddle
import numpy as np
data = np.ones([3, 4], dtype='float32')
linear = paddle.nn.Linear(4, 4)
input = paddle.to_tensor(data)
print(linear.weight)
linear(input)
weight_init_(linear, 'Normal', 'fc_w0', 'fc_b0', std=0.01, mean=0.1)
print(linear.weight)
"""
if hasattr(layer, "weight") and layer.weight is not None:
getattr(init, func)(**kwargs)(layer.weight)
if weight_name is not None:
# override weight name
layer.weight.name = weight_name
if hasattr(layer, "bias") and layer.bias is not None:
init.Constant(bias_value)(layer.bias)
if bias_name is not None:
# override bias name
layer.bias.name = bias_name
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
print(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect."
)
with paddle.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
lower = norm_cdf((a - mean) / std)
upper = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to [2l-1, 2u-1].
tmp = np.random.uniform(
2 * lower - 1, 2 * upper - 1, size=list(tensor.shape)
).astype(np.float32)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tmp = special.erfinv(tmp)
# Transform to proper mean, std
tmp *= std * math.sqrt(2.0)
tmp += mean
# Clamp to ensure it's in the proper range
tmp = np.clip(tmp, a, b)
tensor.set_value(paddle.to_tensor(tmp))
return tensor
def _calculate_fan_in_and_fan_out(tensor):
dimensions = tensor.dim()
if dimensions < 2:
raise ValueError(
"Fan in and fan out can not be computed for tensor "
"with fewer than 2 dimensions"
)
num_input_fmaps = tensor.shape[1]
num_output_fmaps = tensor.shape[0]
receptive_field_size = 1
if tensor.dim() > 2:
receptive_field_size = tensor[0][0].numel()
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def kaiming_normal_(tensor, a=0.0, mode="fan_in", nonlinearity="leaky_relu"):
def _calculate_correct_fan(tensor, mode):
mode = mode.lower()
valid_modes = ["fan_in", "fan_out"]
if mode not in valid_modes:
raise ValueError(
"Mode {} not supported, please use one of {}".format(mode, valid_modes)
)
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == "fan_in" else fan_out
def calculate_gain(nonlinearity, param=None):
linear_fns = [
"linear",
"conv1d",
"conv2d",
"conv3d",
"conv_transpose1d",
"conv_transpose2d",
"conv_transpose3d",
]
if nonlinearity in linear_fns or nonlinearity == "sigmoid":
return 1
elif nonlinearity == "tanh":
return 5.0 / 3
elif nonlinearity == "relu":
return math.sqrt(2.0)
elif nonlinearity == "leaky_relu":
if param is None:
negative_slope = 0.01
elif (
not isinstance(param, bool)
and isinstance(param, int)
or isinstance(param, float)
):
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
return math.sqrt(2.0 / (1 + negative_slope**2))
else:
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
fan = _calculate_correct_fan(tensor, mode)
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
with paddle.no_grad():
paddle.nn.initializer.Normal(0, std)(tensor)
return tensor