移动paddle_detection
This commit is contained in:
@@ -0,0 +1,19 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import optimizer
|
||||
from . import ema
|
||||
|
||||
from .optimizer import *
|
||||
from .ema import *
|
||||
@@ -0,0 +1,272 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
from paddle.optimizer import AdamW
|
||||
from functools import partial
|
||||
import re
|
||||
|
||||
IS_PADDLE_LATER_2_4 = (
|
||||
int(paddle.version.major) >= 2 and
|
||||
int(paddle.version.minor) >= 4) or int(paddle.version.major) == 0
|
||||
|
||||
|
||||
def layerwise_lr_decay(decay_rate, name_dict, n_layers, param):
|
||||
"""
|
||||
Args:
|
||||
decay_rate (float):
|
||||
The layer-wise decay ratio.
|
||||
name_dict (dict):
|
||||
The keys of name_dict is dynamic name of model while the value
|
||||
of name_dict is static name.
|
||||
Use model.named_parameters() to get name_dict.
|
||||
n_layers (int):
|
||||
Total number of layers in the transformer encoder.
|
||||
"""
|
||||
ratio = 1.0
|
||||
static_name = name_dict[param.name]
|
||||
if 'blocks.' in static_name or 'layers.' in static_name:
|
||||
idx_1 = static_name.find('blocks.')
|
||||
idx_2 = static_name.find('layers.')
|
||||
assert any([x >= 0 for x in [idx_1, idx_2]]), ''
|
||||
idx = idx_1 if idx_1 >= 0 else idx_2
|
||||
# idx = re.findall('[blocks|layers]\.(\d+)\.', static_name)[0]
|
||||
|
||||
layer = int(static_name[idx:].split('.')[1])
|
||||
ratio = decay_rate**(n_layers - layer)
|
||||
|
||||
elif 'cls_token' in static_name or 'patch_embed' in static_name or 'pos_embed' in static_name:
|
||||
ratio = decay_rate**(n_layers + 1)
|
||||
|
||||
if IS_PADDLE_LATER_2_4:
|
||||
return ratio
|
||||
else:
|
||||
param.optimize_attr['learning_rate'] *= ratio
|
||||
|
||||
|
||||
class AdamWDL(AdamW):
|
||||
r"""
|
||||
The AdamWDL optimizer is implemented based on the AdamW Optimization with dynamic lr setting.
|
||||
Generally it's used for transformer model.
|
||||
|
||||
We use "layerwise_lr_decay" as default dynamic lr setting method of AdamWDL.
|
||||
“Layer-wise decay” means exponentially decaying the learning rates of individual
|
||||
layers in a top-down manner. For example, suppose the 24-th layer uses a learning
|
||||
rate l, and the Layer-wise decay rate is α, then the learning rate of layer m
|
||||
is lα^(24-m). See more details on: https://arxiv.org/abs/1906.08237.
|
||||
|
||||
.. math::
|
||||
& t = t + 1
|
||||
|
||||
& moment\_1\_out = {\beta}_1 * moment\_1 + (1 - {\beta}_1) * grad
|
||||
|
||||
& moment\_2\_out = {\beta}_2 * moment\_2 + (1 - {\beta}_2) * grad * grad
|
||||
|
||||
& learning\_rate = learning\_rate * \frac{\sqrt{1 - {\beta}_2^t}}{1 - {\beta}_1^t}
|
||||
|
||||
& param\_out = param - learning\_rate * (\frac{moment\_1}{\sqrt{moment\_2} + \epsilon} + \lambda * param)
|
||||
|
||||
Args:
|
||||
learning_rate (float|LRScheduler, optional): The learning rate used to update ``Parameter``.
|
||||
It can be a float value or a LRScheduler. The default value is 0.001.
|
||||
beta1 (float, optional): The exponential decay rate for the 1st moment estimates.
|
||||
It should be a float number or a Tensor with shape [1] and data type as float32.
|
||||
The default value is 0.9.
|
||||
beta2 (float, optional): The exponential decay rate for the 2nd moment estimates.
|
||||
It should be a float number or a Tensor with shape [1] and data type as float32.
|
||||
The default value is 0.999.
|
||||
epsilon (float, optional): A small float value for numerical stability.
|
||||
It should be a float number or a Tensor with shape [1] and data type as float32.
|
||||
The default value is 1e-08.
|
||||
parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``. \
|
||||
This parameter is required in dygraph mode. \
|
||||
The default value is None in static mode, at this time all parameters will be updated.
|
||||
weight_decay (float, optional): The weight decay coefficient, it can be float or Tensor. The default value is 0.01.
|
||||
apply_decay_param_fun (function|None, optional): If it is not None,
|
||||
only tensors that makes apply_decay_param_fun(Tensor.name)==True
|
||||
will be updated. It only works when we want to specify tensors.
|
||||
Default: None.
|
||||
grad_clip (GradientClipBase, optional): Gradient cliping strategy, it's an instance of
|
||||
some derived class of ``GradientClipBase`` . There are three cliping strategies
|
||||
( :ref:`api_fluid_clip_GradientClipByGlobalNorm` , :ref:`api_fluid_clip_GradientClipByNorm` ,
|
||||
:ref:`api_fluid_clip_GradientClipByValue` ). Default None, meaning there is no gradient clipping.
|
||||
lazy_mode (bool, optional): The official Adam algorithm has two moving-average accumulators.
|
||||
The accumulators are updated at every step. Every element of the two moving-average
|
||||
is updated in both dense mode and sparse mode. If the size of parameter is very large,
|
||||
then the update may be very slow. The lazy mode only update the element that has
|
||||
gradient in current mini-batch, so it will be much more faster. But this mode has
|
||||
different semantics with the original Adam algorithm and may lead to different result.
|
||||
The default value is False.
|
||||
multi_precision (bool, optional): Whether to use multi-precision during weight updating. Default is false.
|
||||
layerwise_decay (float, optional): The layer-wise decay ratio. Defaults to 1.0.
|
||||
n_layers (int, optional): The total number of encoder layers. Defaults to 12.
|
||||
set_param_lr_fun (function|None, optional): If it's not None, set_param_lr_fun() will set the the parameter
|
||||
learning rate before it executes Adam Operator. Defaults to :ref:`layerwise_lr_decay`.
|
||||
name_dict (dict, optional): The keys of name_dict is dynamic name of model while the value
|
||||
of name_dict is static name. Use model.named_parameters() to get name_dict.
|
||||
name (str, optional): Normally there is no need for user to set this property.
|
||||
For more information, please refer to :ref:`api_guide_Name`.
|
||||
The default value is None.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
import paddle
|
||||
from paddlenlp.ops.optimizer import AdamWDL
|
||||
def simple_lr_setting(decay_rate, name_dict, n_layers, param):
|
||||
ratio = 1.0
|
||||
static_name = name_dict[param.name]
|
||||
if "weight" in static_name:
|
||||
ratio = decay_rate**0.5
|
||||
param.optimize_attr["learning_rate"] *= ratio
|
||||
|
||||
linear = paddle.nn.Linear(10, 10)
|
||||
|
||||
name_dict = dict()
|
||||
for n, p in linear.named_parameters():
|
||||
name_dict[p.name] = n
|
||||
|
||||
inp = paddle.rand([10,10], dtype="float32")
|
||||
out = linear(inp)
|
||||
loss = paddle.mean(out)
|
||||
|
||||
adamwdl = AdamWDL(
|
||||
learning_rate=1e-4,
|
||||
parameters=linear.parameters(),
|
||||
set_param_lr_fun=simple_lr_setting,
|
||||
layerwise_decay=0.8,
|
||||
name_dict=name_dict)
|
||||
|
||||
loss.backward()
|
||||
adamwdl.step()
|
||||
adamwdl.clear_grad()
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
learning_rate=0.001,
|
||||
beta1=0.9,
|
||||
beta2=0.999,
|
||||
epsilon=1e-8,
|
||||
parameters=None,
|
||||
weight_decay=0.01,
|
||||
apply_decay_param_fun=None,
|
||||
grad_clip=None,
|
||||
lazy_mode=False,
|
||||
multi_precision=False,
|
||||
layerwise_decay=1.0,
|
||||
n_layers=12,
|
||||
set_param_lr_func=None,
|
||||
name_dict=None,
|
||||
name=None):
|
||||
if not isinstance(layerwise_decay, float):
|
||||
raise TypeError("coeff should be float or Tensor.")
|
||||
self.layerwise_decay = layerwise_decay
|
||||
self.n_layers = n_layers
|
||||
self.set_param_lr_func = partial(
|
||||
set_param_lr_func, layerwise_decay, name_dict,
|
||||
n_layers) if set_param_lr_func is not None else set_param_lr_func
|
||||
|
||||
if IS_PADDLE_LATER_2_4:
|
||||
super(AdamWDL, self).__init__(
|
||||
learning_rate=learning_rate,
|
||||
parameters=parameters,
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
epsilon=epsilon,
|
||||
grad_clip=grad_clip,
|
||||
name=name,
|
||||
apply_decay_param_fun=apply_decay_param_fun,
|
||||
weight_decay=weight_decay,
|
||||
lazy_mode=lazy_mode,
|
||||
multi_precision=multi_precision,
|
||||
lr_ratio=self.set_param_lr_func)
|
||||
else:
|
||||
super(AdamWDL, self).__init__(
|
||||
learning_rate=learning_rate,
|
||||
parameters=parameters,
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
epsilon=epsilon,
|
||||
grad_clip=grad_clip,
|
||||
name=name,
|
||||
apply_decay_param_fun=apply_decay_param_fun,
|
||||
weight_decay=weight_decay,
|
||||
lazy_mode=lazy_mode,
|
||||
multi_precision=multi_precision)
|
||||
|
||||
|
||||
def _append_optimize_op(self, block, param_and_grad):
|
||||
if self.set_param_lr_func is None:
|
||||
return super(AdamWDL, self)._append_optimize_op(block, param_and_grad)
|
||||
|
||||
self._append_decoupled_weight_decay(block, param_and_grad)
|
||||
prev_lr = param_and_grad[0].optimize_attr["learning_rate"]
|
||||
self.set_param_lr_func(param_and_grad[0])
|
||||
# excute Adam op
|
||||
res = super(AdamW, self)._append_optimize_op(block, param_and_grad)
|
||||
param_and_grad[0].optimize_attr["learning_rate"] = prev_lr
|
||||
return res
|
||||
|
||||
|
||||
if not IS_PADDLE_LATER_2_4:
|
||||
AdamWDL._append_optimize_op = _append_optimize_op
|
||||
|
||||
|
||||
def build_adamwdl(model,
|
||||
lr=1e-4,
|
||||
weight_decay=0.05,
|
||||
betas=(0.9, 0.999),
|
||||
layer_decay=0.65,
|
||||
num_layers=None,
|
||||
filter_bias_and_bn=True,
|
||||
skip_decay_names=None,
|
||||
set_param_lr_func='layerwise_lr_decay'):
|
||||
|
||||
if skip_decay_names and filter_bias_and_bn:
|
||||
decay_dict = {
|
||||
param.name: not (len(param.shape) == 1 or name.endswith('.bias') or
|
||||
any([_n in name for _n in skip_decay_names]))
|
||||
for name, param in model.named_parameters()
|
||||
}
|
||||
parameters = [p for p in model.parameters()]
|
||||
|
||||
else:
|
||||
parameters = model.parameters()
|
||||
|
||||
opt_args = dict(
|
||||
parameters=parameters, learning_rate=lr, weight_decay=weight_decay)
|
||||
|
||||
if decay_dict is not None:
|
||||
opt_args['apply_decay_param_fun'] = lambda n: decay_dict[n]
|
||||
|
||||
if isinstance(set_param_lr_func, str):
|
||||
set_param_lr_func = eval(set_param_lr_func)
|
||||
opt_args['set_param_lr_func'] = set_param_lr_func
|
||||
|
||||
opt_args['beta1'] = betas[0]
|
||||
opt_args['beta2'] = betas[1]
|
||||
|
||||
opt_args['layerwise_decay'] = layer_decay
|
||||
name_dict = {p.name: n for n, p in model.named_parameters()}
|
||||
|
||||
opt_args['name_dict'] = name_dict
|
||||
opt_args['n_layers'] = num_layers
|
||||
|
||||
optimizer = AdamWDL(**opt_args)
|
||||
|
||||
return optimizer
|
||||
195
services/paddle_services/paddle_detection/ppdet/optimizer/ema.py
Normal file
195
services/paddle_services/paddle_detection/ppdet/optimizer/ema.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import math
|
||||
import paddle
|
||||
import weakref
|
||||
from copy import deepcopy
|
||||
|
||||
from .utils import get_bn_running_state_names
|
||||
|
||||
__all__ = ['ModelEMA', 'SimpleModelEMA']
|
||||
|
||||
|
||||
class ModelEMA(object):
|
||||
"""
|
||||
Exponential Weighted Average for Deep Neutal Networks
|
||||
Args:
|
||||
model (nn.Layer): Detector of model.
|
||||
decay (int): The decay used for updating ema parameter.
|
||||
Ema's parameter are updated with the formula:
|
||||
`ema_param = decay * ema_param + (1 - decay) * cur_param`.
|
||||
Defaults is 0.9998.
|
||||
ema_decay_type (str): type in ['threshold', 'normal', 'exponential'],
|
||||
'threshold' as default.
|
||||
cycle_epoch (int): The epoch of interval to reset ema_param and
|
||||
step. Defaults is -1, which means not reset. Its function is to
|
||||
add a regular effect to ema, which is set according to experience
|
||||
and is effective when the total training epoch is large.
|
||||
ema_black_list (set|list|tuple, optional): The custom EMA black_list.
|
||||
Blacklist of weight names that will not participate in EMA
|
||||
calculation. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model,
|
||||
decay=0.9998,
|
||||
ema_decay_type='threshold',
|
||||
cycle_epoch=-1,
|
||||
ema_black_list=None,
|
||||
ema_filter_no_grad=False):
|
||||
self.step = 0
|
||||
self.epoch = 0
|
||||
self.decay = decay
|
||||
self.ema_decay_type = ema_decay_type
|
||||
self.cycle_epoch = cycle_epoch
|
||||
self.ema_black_list = self._match_ema_black_list(
|
||||
model.state_dict().keys(), ema_black_list)
|
||||
bn_states_names = get_bn_running_state_names(model)
|
||||
if ema_filter_no_grad:
|
||||
for n, p in model.named_parameters():
|
||||
if p.stop_gradient and n not in bn_states_names:
|
||||
self.ema_black_list.add(n)
|
||||
|
||||
self.state_dict = dict()
|
||||
for k, v in model.state_dict().items():
|
||||
if k in self.ema_black_list:
|
||||
self.state_dict[k] = v
|
||||
else:
|
||||
self.state_dict[k] = paddle.zeros_like(v, dtype='float32')
|
||||
|
||||
self._model_state = {
|
||||
k: weakref.ref(p)
|
||||
for k, p in model.state_dict().items()
|
||||
}
|
||||
|
||||
def reset(self):
|
||||
self.step = 0
|
||||
self.epoch = 0
|
||||
for k, v in self.state_dict.items():
|
||||
if k in self.ema_black_list:
|
||||
self.state_dict[k] = v
|
||||
else:
|
||||
self.state_dict[k] = paddle.zeros_like(v)
|
||||
|
||||
def resume(self, state_dict, step=0):
|
||||
for k, v in state_dict.items():
|
||||
if k in self.state_dict:
|
||||
if self.state_dict[k].dtype == v.dtype:
|
||||
self.state_dict[k] = v
|
||||
else:
|
||||
self.state_dict[k] = v.astype(self.state_dict[k].dtype)
|
||||
self.step = step
|
||||
|
||||
def update(self, model=None):
|
||||
if self.ema_decay_type == 'threshold':
|
||||
decay = min(self.decay, (1 + self.step) / (10 + self.step))
|
||||
elif self.ema_decay_type == 'exponential':
|
||||
decay = self.decay * (1 - math.exp(-(self.step + 1) / 2000))
|
||||
else:
|
||||
decay = self.decay
|
||||
self._decay = decay
|
||||
|
||||
if model is not None:
|
||||
model_dict = model.state_dict()
|
||||
else:
|
||||
model_dict = {k: p() for k, p in self._model_state.items()}
|
||||
assert all(
|
||||
[v is not None for _, v in model_dict.items()]), 'python gc.'
|
||||
|
||||
for k, v in self.state_dict.items():
|
||||
if k not in self.ema_black_list:
|
||||
v = decay * v + (1 - decay) * model_dict[k].astype('float32')
|
||||
v.stop_gradient = True
|
||||
self.state_dict[k] = v
|
||||
self.step += 1
|
||||
|
||||
def apply(self):
|
||||
if self.step == 0:
|
||||
return self.state_dict
|
||||
state_dict = dict()
|
||||
model_dict = {k: p() for k, p in self._model_state.items()}
|
||||
for k, v in self.state_dict.items():
|
||||
if k in self.ema_black_list:
|
||||
v.stop_gradient = True
|
||||
state_dict[k] = v
|
||||
else:
|
||||
if self.ema_decay_type != 'exponential':
|
||||
v = v / (1 - self._decay**self.step)
|
||||
v = v.astype(model_dict[k].dtype)
|
||||
v.stop_gradient = True
|
||||
state_dict[k] = v
|
||||
self.epoch += 1
|
||||
if self.cycle_epoch > 0 and self.epoch == self.cycle_epoch:
|
||||
self.reset()
|
||||
|
||||
return state_dict
|
||||
|
||||
def _match_ema_black_list(self, weight_name, ema_black_list=None):
|
||||
out_list = set()
|
||||
if ema_black_list:
|
||||
for name in weight_name:
|
||||
for key in ema_black_list:
|
||||
if key in name:
|
||||
out_list.add(name)
|
||||
return out_list
|
||||
|
||||
|
||||
class SimpleModelEMA(object):
|
||||
"""
|
||||
Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
|
||||
Keep a moving average of everything in the model state_dict (parameters and buffers).
|
||||
This is intended to allow functionality like
|
||||
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
|
||||
A smoothed version of the weights is necessary for some training schemes to perform well.
|
||||
This class is sensitive where it is initialized in the sequence of model init,
|
||||
GPU assignment and distributed training wrappers.
|
||||
"""
|
||||
|
||||
def __init__(self, model=None, decay=0.9996):
|
||||
"""
|
||||
Args:
|
||||
model (nn.Module): model to apply EMA.
|
||||
decay (float): ema decay reate.
|
||||
"""
|
||||
self.model = deepcopy(model)
|
||||
self.decay = decay
|
||||
|
||||
def update(self, model, decay=None):
|
||||
if decay is None:
|
||||
decay = self.decay
|
||||
|
||||
with paddle.no_grad():
|
||||
state = {}
|
||||
msd = model.state_dict()
|
||||
for k, v in self.model.state_dict().items():
|
||||
if paddle.is_floating_point(v):
|
||||
v *= decay
|
||||
v += (1.0 - decay) * msd[k].detach()
|
||||
state[k] = v
|
||||
self.model.set_state_dict(state)
|
||||
|
||||
def resume(self, state_dict, step=0):
|
||||
state = {}
|
||||
msd = state_dict
|
||||
for k, v in self.model.state_dict().items():
|
||||
if paddle.is_floating_point(v):
|
||||
v = msd[k].detach()
|
||||
state[k] = v
|
||||
self.model.set_state_dict(state)
|
||||
self.step = step
|
||||
@@ -0,0 +1,358 @@
|
||||
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
import math
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
import paddle.optimizer as optimizer
|
||||
import paddle.regularizer as regularizer
|
||||
|
||||
from ppdet.core.workspace import register, serializable
|
||||
import copy
|
||||
|
||||
from .adamw import AdamWDL, build_adamwdl
|
||||
|
||||
__all__ = ['LearningRate', 'OptimizerBuilder']
|
||||
|
||||
from ppdet.utils.logger import setup_logger
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
@serializable
|
||||
class CosineDecay(object):
|
||||
"""
|
||||
Cosine learning rate decay
|
||||
|
||||
Args:
|
||||
max_epochs (int): max epochs for the training process.
|
||||
if you commbine cosine decay with warmup, it is recommended that
|
||||
the max_iters is much larger than the warmup iter
|
||||
use_warmup (bool): whether to use warmup. Default: True.
|
||||
min_lr_ratio (float): minimum learning rate ratio. Default: 0.
|
||||
last_plateau_epochs (int): use minimum learning rate in
|
||||
the last few epochs. Default: 0.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
max_epochs=1000,
|
||||
use_warmup=True,
|
||||
min_lr_ratio=0.,
|
||||
last_plateau_epochs=0):
|
||||
self.max_epochs = max_epochs
|
||||
self.use_warmup = use_warmup
|
||||
self.min_lr_ratio = min_lr_ratio
|
||||
self.last_plateau_epochs = last_plateau_epochs
|
||||
|
||||
def __call__(self,
|
||||
base_lr=None,
|
||||
boundary=None,
|
||||
value=None,
|
||||
step_per_epoch=None):
|
||||
assert base_lr is not None, "either base LR or values should be provided"
|
||||
|
||||
max_iters = self.max_epochs * int(step_per_epoch)
|
||||
last_plateau_iters = self.last_plateau_epochs * int(step_per_epoch)
|
||||
min_lr = base_lr * self.min_lr_ratio
|
||||
if boundary is not None and value is not None and self.use_warmup:
|
||||
# use warmup
|
||||
warmup_iters = len(boundary)
|
||||
for i in range(int(boundary[-1]), max_iters):
|
||||
boundary.append(i)
|
||||
if i < max_iters - last_plateau_iters:
|
||||
decayed_lr = min_lr + (base_lr - min_lr) * 0.5 * (math.cos(
|
||||
(i - warmup_iters) * math.pi /
|
||||
(max_iters - warmup_iters - last_plateau_iters)) + 1)
|
||||
value.append(decayed_lr)
|
||||
else:
|
||||
value.append(min_lr)
|
||||
return optimizer.lr.PiecewiseDecay(boundary, value)
|
||||
elif last_plateau_iters > 0:
|
||||
# not use warmup, but set `last_plateau_epochs` > 0
|
||||
boundary = []
|
||||
value = []
|
||||
for i in range(max_iters):
|
||||
if i < max_iters - last_plateau_iters:
|
||||
decayed_lr = min_lr + (base_lr - min_lr) * 0.5 * (math.cos(
|
||||
i * math.pi / (max_iters - last_plateau_iters)) + 1)
|
||||
value.append(decayed_lr)
|
||||
else:
|
||||
value.append(min_lr)
|
||||
if i > 0:
|
||||
boundary.append(i)
|
||||
return optimizer.lr.PiecewiseDecay(boundary, value)
|
||||
|
||||
return optimizer.lr.CosineAnnealingDecay(
|
||||
base_lr, T_max=max_iters, eta_min=min_lr)
|
||||
|
||||
|
||||
@serializable
|
||||
class PiecewiseDecay(object):
|
||||
"""
|
||||
Multi step learning rate decay
|
||||
|
||||
Args:
|
||||
gamma (float | list): decay factor
|
||||
milestones (list): steps at which to decay learning rate
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
gamma=[0.1, 0.01],
|
||||
milestones=[8, 11],
|
||||
values=None,
|
||||
use_warmup=True):
|
||||
super(PiecewiseDecay, self).__init__()
|
||||
if type(gamma) is not list:
|
||||
self.gamma = []
|
||||
for i in range(len(milestones)):
|
||||
self.gamma.append(gamma / 10**i)
|
||||
else:
|
||||
self.gamma = gamma
|
||||
self.milestones = milestones
|
||||
self.values = values
|
||||
self.use_warmup = use_warmup
|
||||
|
||||
def __call__(self,
|
||||
base_lr=None,
|
||||
boundary=None,
|
||||
value=None,
|
||||
step_per_epoch=None):
|
||||
if boundary is not None and self.use_warmup:
|
||||
boundary.extend([int(step_per_epoch) * i for i in self.milestones])
|
||||
else:
|
||||
# do not use LinearWarmup
|
||||
boundary = [int(step_per_epoch) * i for i in self.milestones]
|
||||
value = [base_lr] # during step[0, boundary[0]] is base_lr
|
||||
|
||||
# self.values is setted directly in config
|
||||
if self.values is not None:
|
||||
assert len(self.milestones) + 1 == len(self.values)
|
||||
return optimizer.lr.PiecewiseDecay(boundary, self.values)
|
||||
|
||||
# value is computed by self.gamma
|
||||
value = value if value is not None else [base_lr]
|
||||
for i in self.gamma:
|
||||
value.append(base_lr * i)
|
||||
|
||||
return optimizer.lr.PiecewiseDecay(boundary, value)
|
||||
|
||||
|
||||
@serializable
|
||||
class LinearWarmup(object):
|
||||
"""
|
||||
Warm up learning rate linearly
|
||||
|
||||
Args:
|
||||
steps (int): warm up steps
|
||||
start_factor (float): initial learning rate factor
|
||||
epochs (int|None): use epochs as warm up steps, the priority
|
||||
of `epochs` is higher than `steps`. Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self, steps=500, start_factor=1. / 3, epochs=None, epochs_first=True):
|
||||
super(LinearWarmup, self).__init__()
|
||||
self.steps = steps
|
||||
self.start_factor = start_factor
|
||||
self.epochs = epochs
|
||||
self.epochs_first = epochs_first
|
||||
|
||||
def __call__(self, base_lr, step_per_epoch):
|
||||
boundary = []
|
||||
value = []
|
||||
if self.epochs_first and self.epochs is not None:
|
||||
warmup_steps = self.epochs * step_per_epoch
|
||||
else:
|
||||
warmup_steps = self.steps
|
||||
warmup_steps = max(warmup_steps, 1)
|
||||
for i in range(warmup_steps + 1):
|
||||
if warmup_steps > 0:
|
||||
alpha = i / warmup_steps
|
||||
factor = self.start_factor * (1 - alpha) + alpha
|
||||
lr = base_lr * factor
|
||||
value.append(lr)
|
||||
if i > 0:
|
||||
boundary.append(i)
|
||||
return boundary, value
|
||||
|
||||
|
||||
@serializable
|
||||
class ExpWarmup(object):
|
||||
"""
|
||||
Warm up learning rate in exponential mode
|
||||
Args:
|
||||
steps (int): warm up steps.
|
||||
epochs (int|None): use epochs as warm up steps, the priority
|
||||
of `epochs` is higher than `steps`. Default: None.
|
||||
power (int): Exponential coefficient. Default: 2.
|
||||
"""
|
||||
|
||||
def __init__(self, steps=1000, epochs=None, power=2):
|
||||
super(ExpWarmup, self).__init__()
|
||||
self.steps = steps
|
||||
self.epochs = epochs
|
||||
self.power = power
|
||||
|
||||
def __call__(self, base_lr, step_per_epoch):
|
||||
boundary = []
|
||||
value = []
|
||||
warmup_steps = self.epochs * step_per_epoch if self.epochs is not None else self.steps
|
||||
warmup_steps = max(warmup_steps, 1)
|
||||
for i in range(warmup_steps + 1):
|
||||
factor = (i / float(warmup_steps))**self.power
|
||||
value.append(base_lr * factor)
|
||||
if i > 0:
|
||||
boundary.append(i)
|
||||
return boundary, value
|
||||
|
||||
|
||||
@register
|
||||
class LearningRate(object):
|
||||
"""
|
||||
Learning Rate configuration
|
||||
|
||||
Args:
|
||||
base_lr (float): base learning rate
|
||||
schedulers (list): learning rate schedulers
|
||||
"""
|
||||
__category__ = 'optim'
|
||||
|
||||
def __init__(self,
|
||||
base_lr=0.01,
|
||||
schedulers=[PiecewiseDecay(), LinearWarmup()]):
|
||||
super(LearningRate, self).__init__()
|
||||
self.base_lr = base_lr
|
||||
self.schedulers = []
|
||||
|
||||
schedulers = copy.deepcopy(schedulers)
|
||||
for sched in schedulers:
|
||||
if isinstance(sched, dict):
|
||||
# support dict sched instantiate
|
||||
module = sys.modules[__name__]
|
||||
type = sched.pop("name")
|
||||
scheduler = getattr(module, type)(**sched)
|
||||
self.schedulers.append(scheduler)
|
||||
else:
|
||||
self.schedulers.append(sched)
|
||||
|
||||
def __call__(self, step_per_epoch):
|
||||
assert len(self.schedulers) >= 1
|
||||
if not self.schedulers[0].use_warmup:
|
||||
return self.schedulers[0](base_lr=self.base_lr,
|
||||
step_per_epoch=step_per_epoch)
|
||||
|
||||
# TODO: split warmup & decay
|
||||
# warmup
|
||||
boundary, value = self.schedulers[1](self.base_lr, step_per_epoch)
|
||||
# decay
|
||||
decay_lr = self.schedulers[0](self.base_lr, boundary, value,
|
||||
step_per_epoch)
|
||||
return decay_lr
|
||||
|
||||
|
||||
@register
|
||||
class OptimizerBuilder():
|
||||
"""
|
||||
Build optimizer handles
|
||||
Args:
|
||||
regularizer (object): an `Regularizer` instance
|
||||
optimizer (object): an `Optimizer` instance
|
||||
"""
|
||||
__category__ = 'optim'
|
||||
|
||||
def __init__(self,
|
||||
clip_grad_by_norm=None,
|
||||
clip_grad_by_value=None,
|
||||
regularizer={'type': 'L2',
|
||||
'factor': .0001},
|
||||
optimizer={'type': 'Momentum',
|
||||
'momentum': .9}):
|
||||
self.clip_grad_by_norm = clip_grad_by_norm
|
||||
self.clip_grad_by_value = clip_grad_by_value
|
||||
self.regularizer = regularizer
|
||||
self.optimizer = optimizer
|
||||
|
||||
def __call__(self, learning_rate, model=None):
|
||||
if self.clip_grad_by_norm is not None:
|
||||
grad_clip = nn.ClipGradByGlobalNorm(
|
||||
clip_norm=self.clip_grad_by_norm)
|
||||
elif self.clip_grad_by_value is not None:
|
||||
var = abs(self.clip_grad_by_value)
|
||||
grad_clip = nn.ClipGradByValue(min=-var, max=var)
|
||||
else:
|
||||
grad_clip = None
|
||||
if self.regularizer and self.regularizer != 'None':
|
||||
reg_type = self.regularizer['type'] + 'Decay'
|
||||
reg_factor = self.regularizer['factor']
|
||||
regularization = getattr(regularizer, reg_type)(reg_factor)
|
||||
else:
|
||||
regularization = None
|
||||
|
||||
optim_args = self.optimizer.copy()
|
||||
optim_type = optim_args['type']
|
||||
del optim_args['type']
|
||||
|
||||
if optim_type == 'AdamWDL':
|
||||
return build_adamwdl(model, lr=learning_rate, **optim_args)
|
||||
|
||||
if optim_type != 'AdamW':
|
||||
optim_args['weight_decay'] = regularization
|
||||
|
||||
op = getattr(optimizer, optim_type)
|
||||
|
||||
if 'param_groups' in optim_args:
|
||||
assert isinstance(optim_args['param_groups'], list), ''
|
||||
|
||||
param_groups = optim_args.pop('param_groups')
|
||||
|
||||
params, visited = [], []
|
||||
for group in param_groups:
|
||||
assert isinstance(group,
|
||||
dict) and 'params' in group and isinstance(
|
||||
group['params'], list), ''
|
||||
_params = {
|
||||
n: p
|
||||
for n, p in model.named_parameters()
|
||||
if any([k in n
|
||||
for k in group['params']]) and p.trainable is True
|
||||
}
|
||||
_group = group.copy()
|
||||
_group.update({'params': list(_params.values())})
|
||||
|
||||
params.append(_group)
|
||||
visited.extend(list(_params.keys()))
|
||||
|
||||
ext_params = [
|
||||
p for n, p in model.named_parameters()
|
||||
if n not in visited and p.trainable is True
|
||||
]
|
||||
|
||||
if len(ext_params) < len(model.parameters()):
|
||||
params.append({'params': ext_params})
|
||||
|
||||
elif len(ext_params) > len(model.parameters()):
|
||||
raise RuntimeError
|
||||
|
||||
else:
|
||||
_params = model.parameters()
|
||||
params = [param for param in _params if param.trainable is True]
|
||||
|
||||
return op(learning_rate=learning_rate,
|
||||
parameters=params,
|
||||
grad_clip=grad_clip,
|
||||
**optim_args)
|
||||
@@ -0,0 +1,37 @@
|
||||
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import paddle
|
||||
import paddle.nn as nn
|
||||
|
||||
from typing import List
|
||||
|
||||
|
||||
def get_bn_running_state_names(model: nn.Layer) -> List[str]:
|
||||
"""Get all bn state full names including running mean and variance
|
||||
"""
|
||||
names = []
|
||||
for n, m in model.named_sublayers():
|
||||
if isinstance(m, (nn.BatchNorm2D, nn.SyncBatchNorm)):
|
||||
assert hasattr(m, '_mean'), f'assert {m} has _mean'
|
||||
assert hasattr(m, '_variance'), f'assert {m} has _variance'
|
||||
running_mean = f'{n}._mean'
|
||||
running_var = f'{n}._variance'
|
||||
names.extend([running_mean, running_var])
|
||||
|
||||
return names
|
||||
Reference in New Issue
Block a user