Loss scale in Megatron-LM
Loss scaling is widely used in mixed-precision training to scale up the parameter gradients computed during back-propagation. This article walks through the complete path in the Megatron-LM framework, from configuring the loss scale to where it is actually applied, in the hope of deepening the reader's understanding.
1. Loss-scale initialization
1.1 Hyperparameter initialization
Megatron supports two ways of setting the loss scale, constant and dynamic:
- Constant loss scale: a fixed scaling factor is used throughout training; the factor is a hyperparameter that has to be tuned empirically.
- Dynamic loss scale: as training progresses, the scaling factor is grown or backed off at fixed intervals or when certain conditions are triggered. In short, the framework searches for the most suitable scale on its own, which reduces the tuning burden and makes training more automated.
See the relevant part of megatron/arguments.py:
group.add_argument('--loss-scale', type=float, default=None,
                   help='Static loss scaling, positive power of 2 '
                   'values can improve fp16 convergence. If None, dynamic '
                   'loss scaling is used.')
group.add_argument('--initial-loss-scale', type=float, default=2**32,
                   help='Initial loss-scale for dynamic loss scaling.')
group.add_argument('--min-loss-scale', type=float, default=1.0,
                   help='Minimum loss scale for dynamic loss scale.')
group.add_argument('--loss-scale-window', type=float, default=1000,
                   help='Window over which to raise/lower dynamic scale.')
group.add_argument('--hysteresis', type=int, default=2,
                   help='hysteresis for dynamic loss scaling')
- --loss-scale: sets a constant loss-scale value, which must be a positive power of 2; if left unset, dynamic loss scaling is used.
- --initial-loss-scale: the starting loss-scale for dynamic scaling, default $2^{32}$.
- --min-loss-scale: the minimum loss-scale for dynamic scaling, default 1.0.
- --loss-scale-window: the growth interval for dynamic scaling, i.e. how many iterations must pass without overflow before the scale is raised, default 1000 iters.
- --hysteresis: the backoff hysteresis for dynamic scaling, i.e. how many nan/inf occurrences are tolerated before the scale is lowered, default 2.
Note:
- If no loss-scale flag is set at all, dynamic loss scaling is the default under the distributed optimizer or mixed-precision training, starting from a very high initial scale and repeatedly backing off in search of the most suitable value.
- Scaling in mixed-precision training can be effectively disabled by setting a constant loss scale of 1.0 (useful in special situations).
Once --loss-scale is set, the constant loss scale takes precedence. As for --initial-loss-scale, the starting value of dynamic scaling, the author usually picks 32768 ($2^{15}$) or 65536 ($2^{16}$) from experience.
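As a quick illustration, the sketch below (a minimal reconstruction that reuses only the flags defined above, outside of any Megatron plumbing) shows how the three common configurations map onto the parsed arguments:
import argparse

# Minimal sketch: only the loss-scale flags from megatron/arguments.py are reproduced here.
parser = argparse.ArgumentParser()
parser.add_argument('--loss-scale', type=float, default=None)
parser.add_argument('--initial-loss-scale', type=float, default=2**32)
parser.add_argument('--min-loss-scale', type=float, default=1.0)
parser.add_argument('--loss-scale-window', type=float, default=1000)
parser.add_argument('--hysteresis', type=int, default=2)

# Dynamic scaling (the default): no flags given, the scale starts at 2**32 and backs off on overflow.
print(parser.parse_args([]).loss_scale)                        # None -> dynamic
# Constant scaling: a fixed factor, e.g. 1024.
print(parser.parse_args(['--loss-scale', '1024']).loss_scale)  # 1024.0
# Effectively no scaling: a constant factor of 1.0.
print(parser.parse_args(['--loss-scale', '1.0']).loss_scale)   # 1.0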
1.2 Initialization in the optimizer classes
The flags above are consumed in megatron/optimizer/__init__.py:
def get_megatron_optimizer(...):
    args = get_args()
    ...
    if args.fp16 or args.bf16 or args.use_distributed_optimizer:
        grad_scaler = None
        # Constant loss scale.
        if args.loss_scale:
            grad_scaler = ConstantGradScaler(args.loss_scale)
        # Dynamic loss scale.
        else:
            if args.fp16:
                grad_scaler = DynamicGradScaler(
                    initial_scale=args.initial_loss_scale,
                    min_scale=args.min_loss_scale,
                    growth_factor=2.0,
                    backoff_factor=0.5,
                    growth_interval=args.loss_scale_window,
                    hysteresis=args.hysteresis)

        # Megatron optimizer.
        opt_ty = DistributedOptimizer \
            if args.use_distributed_optimizer else \
            Float16OptimizerWithFloat16Params
        return opt_ty(optimizer,
                      args.clip_grad, ..., grad_scaler, model)
A loss scale is only used when the distributed optimizer or fp16/bf16 (mixed precision) mode is enabled:
- First check whether a constant scale was requested (usable with both mixed precision and the distributed optimizer).
- Otherwise fall back to a dynamic scale (only enabled for fp16 mixed-precision training).
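To see why fp16 needs loss scaling in the first place, here is a small self-contained demonstration (plain PyTorch, not Megatron code; the concrete numbers are only illustrative) of the underflow that scaling works around:
import torch

# A gradient of 1e-8 is below fp16's smallest subnormal (~6e-8) and flushes to zero.
tiny_grad = torch.tensor(1e-8).to(torch.float16)
print(tiny_grad)                    # tensor(0., dtype=torch.float16)

# Scale the value by 2**14 before casting to fp16, then unscale in fp32:
scale = 2.0 ** 14
scaled = (torch.tensor(1e-8) * scale).to(torch.float16)   # ~1.64e-4, representable in fp16
recovered = scaled.float() / scale
print(recovered)                    # ~1e-8, the small gradient survives
bf16 has the same exponent range as fp32, which is why the dynamic scaler above is only constructed for fp16.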
2. Managing the scale value
The loss-scale parameters above are used to initialize grad_scaler, which is eventually passed to the optimizer. The concrete GradScaler implementations live in megatron/optimizer/grad_scaler.py:
2.1 MegatronGradScaler (Abstract)
class MegatronGradScaler(ABC):

    def __init__(self, initial_scale):
        """Initialize scale value with the input initial scale."""
        assert initial_scale > 0.0
        # modified for cpu use
        # self._scale = torch.cuda.FloatTensor([initial_scale])
        self._scale = torch.FloatTensor([initial_scale])

    @property
    def scale(self):
        return self._scale

    @property
    def inv_scale(self):
        return self._scale.double().reciprocal().float()

    @abstractmethod
    def update(self, found_inf):
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        pass
MegatronGradScaler is an abstract base class that defines the basic interface and some initialization. _scale is a protected member (by convention, accessed and modified only through class methods) that holds the current scale value at run time. The scale and inv_scale properties return the scale and its reciprocal, used to scale the loss and to undo the scaling respectively. The three abstract methods update, state_dict and load_state_dict exist mainly to support dynamic scaling.
2.2 Constant Grad Scaler
class ConstantGradScaler(MegatronGradScaler):

    def update(self, found_inf):
        pass

    def state_dict(self):
        return dict()

    def load_state_dict(self, state_dict):
        pass
The constant scaler is very simple: the scale value never needs updating, and _scale is set from the --loss-scale flag in the constructor inherited from the base class.
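A quick usage sketch (assuming the classes above can be imported from megatron.optimizer.grad_scaler; the exact module path may differ between Megatron versions):
from megatron.optimizer.grad_scaler import ConstantGradScaler

scaler = ConstantGradScaler(1024.0)
print(scaler.scale)             # tensor([1024.])
print(scaler.inv_scale)         # tensor([0.0010]), i.e. 1/1024
scaler.update(found_inf=True)   # a no-op for the constant scaler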
2.3 Dynamic Grad Scaler
Next we look at the dynamic scaler, going through its __init__, state_dict and update logic in turn.
class DynamicGradScaler(MegatronGradScaler):

    def __init__(self, initial_scale, min_scale,
                 growth_factor, backoff_factor,
                 growth_interval, hysteresis):
        """Grad scaler with dynamic scale that gets adjusted
        during training."""
        super(DynamicGradScaler, self).__init__(initial_scale)

        # Lower bound on the scale.
        assert min_scale > 0.0
        assert min_scale <= initial_scale
        # modified for cpu use
        # self.min_scale = torch.cuda.FloatTensor([min_scale])
        self.min_scale = torch.FloatTensor([min_scale])

        # Growth and backoff factors for the scale.
        assert growth_factor > 1.0
        # self.growth_factor = torch.cuda.FloatTensor([growth_factor])
        self.growth_factor = torch.FloatTensor([growth_factor])
        assert backoff_factor < 1.0
        assert backoff_factor > 0.0
        # self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
        self.backoff_factor = torch.FloatTensor([backoff_factor])

        # Interval over which if we don't see any inf/nan,
        # we will scale the grad scale by the growth factor.
        assert growth_interval > 0
        self.growth_interval = growth_interval

        # Number of inf/nans we should see before scaling down
        # the grad scale by the backoff factor.
        assert hysteresis > 0
        self.hysteresis = hysteresis

        # Trackers.
        self._growth_tracker = 0
        self._hysteresis_tracker = self.hysteresis
There is not much to analyze in the constructor itself, just a series of asserts and bookkeeping, but the member variables are worth listing:
- Set from user flags:
  - initial_scale: initial scale value, range (0, +∞), set by --initial-loss-scale, default $2^{32}$.
  - min_scale: minimum scale value, range (0, initial_scale], set by --min-loss-scale, default 1.0.
  - growth_interval: growth interval of the scale, range (0, +∞), set by --loss-scale-window, default 1000 iterations.
  - hysteresis: backoff hysteresis of the scale, a positive integer, set by --hysteresis, default 2.
- Hard-coded in get_megatron_optimizer (see 1.2):
  - growth_factor: factor by which the scale grows, range (1.0, +∞), hard-coded to 2.0.
  - backoff_factor: factor by which the scale shrinks, range (0.0, 1.0), hard-coded to 0.5.
- Protected members of this class:
  - _growth_tracker: growth counter, initialized to 0.
  - _hysteresis_tracker: hysteresis (backoff) counter, initialized to self.hysteresis.
The state_dict part:
def state_dict(self):
    state_dict = {}
    state_dict['scale'] = self._scale
    state_dict['growth_tracker'] = self._growth_tracker
    state_dict['hysteresis_tracker'] = self._hysteresis_tracker
    return state_dict

def load_state_dict(self, state_dict):
    # modified for cpu use
    # self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
    self._scale = state_dict['scale']
    self._growth_tracker = state_dict['growth_tracker']
    self._hysteresis_tracker = state_dict['hysteresis_tracker']
The state dict has to save the current _scale value, the growth counter (_growth_tracker) and the hysteresis counter (_hysteresis_tracker).
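A round-trip sketch (hypothetical values; it assumes DynamicGradScaler can be imported as in the earlier example):
from megatron.optimizer.grad_scaler import DynamicGradScaler

scaler = DynamicGradScaler(initial_scale=2**16, min_scale=1.0,
                           growth_factor=2.0, backoff_factor=0.5,
                           growth_interval=1000, hysteresis=2)
state = scaler.state_dict()
# ... the state is written into a checkpoint and read back later ...
restored = DynamicGradScaler(initial_scale=2**16, min_scale=1.0,
                             growth_factor=2.0, backoff_factor=0.5,
                             growth_interval=1000, hysteresis=2)
restored.load_state_dict(state)
assert restored.scale.item() == scaler.scale.item()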
The dynamic update of the scale value (in effect a search / self-adaptation process):
def update(self, found_inf):
    # If we have an inf/nan, growth tracker is set to 0
    # and hysterisis tracker is reduced by 1.
    if found_inf:
        self._growth_tracker = 0
        self._hysteresis_tracker -= 1
        # Now if we are out of hysteresis count, scale down the loss.
        if self._hysteresis_tracker <= 0:
            self._scale = torch.max(self._scale * self.backoff_factor,
                                    self.min_scale)
    else:
        # If there is no nan/inf, increment the growth tracker.
        self._growth_tracker += 1
        # If we have had enough consequitive intervals with no nan/inf:
        if self._growth_tracker == self.growth_interval:
            # Reset the tracker and hysteresis trackers,
            self._growth_tracker = 0
            self._hysteresis_tracker = self.hysteresis
            # and scale up the loss scale.
            self._scale = self._scale * self.growth_factor
The adaptation of _scale splits into two cases:
- Backoff: if an inf or nan is found in the gradients, the growth counter _growth_tracker is reset to zero and _hysteresis_tracker is decremented by 1; once _hysteresis_tracker reaches 0 or below, _scale itself is backed off (halved). The comparison is "less than or equal" because once a backoff has been triggered (the counter is, so to speak, wounded), every further inf/nan causes another backoff, until a growth event restores the counter to full health.
- Growth: while no inf or nan is seen, the growth counter _growth_tracker keeps incrementing; when it reaches the growth interval, _hysteresis_tracker is reset (back to full health) and _scale is grown (doubled).
In short:
- After the previous dynamic adjustment, once growth_interval (1000) consecutive iterations pass without any nan or inf, the scale is doubled (x2) and both counters are reset immediately.
- After the previous growth of the scale, once nan or inf has occurred hysteresis (2) times in total (not necessarily consecutively), the scale is halved (x0.5). After the previous backoff of the scale, every further nan/inf occurrence halves the scale again.
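This behaviour is easy to see with a small standalone simulation of update (it assumes DynamicGradScaler is importable as above; a short growth_interval is used purely to keep the trace readable):
from megatron.optimizer.grad_scaler import DynamicGradScaler

scaler = DynamicGradScaler(initial_scale=2**16, min_scale=1.0,
                           growth_factor=2.0, backoff_factor=0.5,
                           growth_interval=5, hysteresis=2)

# The first overflow only consumes hysteresis, the second halves the scale.
scaler.update(found_inf=True);  print(scaler.scale.item())   # 65536.0
scaler.update(found_inf=True);  print(scaler.scale.item())   # 32768.0
# Right after a backoff, every further overflow halves the scale again.
scaler.update(found_inf=True);  print(scaler.scale.item())   # 16384.0
# Five clean iterations in a row double the scale and restore the hysteresis budget.
for _ in range(5):
    scaler.update(found_inf=False)
print(scaler.scale.item())                                    # 32768.0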
3. Applying the scale
This section digs into how grad_scaler actually takes effect inside the optimizer. We refer to megatron/optimizer/optimizer.py and focus on the FP16 mixed-precision optimizer.
Class hierarchy:
ABC
 |— MegatronOptimizer
      |— FP32Optimizer
      |— MixedPrecisionOptimizer
           |— Float16OptimizerWithFloat16Params
3.1 Scaling the loss
MegatronOptimizer implements the simplest part of the scaling: once the forward pass has produced the loss, it is multiplied by the currently maintained _scale value, and the chain rule of back-propagation then automatically carries that factor into the gradient of every parameter.
class MegatronOptimizer(ABC):
    ...
    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass

    # Scale the loss by the current loss scale.
    def scale_loss(self, loss):
        """Simple scaling."""
        return self.get_loss_scale() * loss
In train_step() in megatron/training.py, the optimizer's scale_loss method is passed down:
pretrain → train → train_step → get_forward_backward_func
def train_step(forward_step_func, ..., optimizer, ...):
    forward_backward_func = get_forward_backward_func()
    ...
    losses_reduced = forward_backward_func(
        forward_step_func=forward_step_func,
        ...
        grad_scaler=optimizer.scale_loss,
        ...)
    ...
Following the call chain further down, we find that the grad_scaler argument is finally consumed by the backward_step() function in megatron/core/pipeline_parallel/schedules.py.
forward_backward_func → get_forward_backward_func → forward_backward_no_pipelining → backward_step
def backward_step(grad_scaler, ..., output_tensor, ...):
    ...
    # Backward pass.
    if output_tensor_grad[0] is None and grad_scaler is not None:
        output_tensor = grad_scaler(output_tensor[0])
Looking at forward_step, it is easy to see that output_tensor[0] here is exactly the loss computed by the forward pass. This matches our understanding of loss scaling: the backward pass runs on a loss that has been multiplied by the scale value.
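Putting the pieces together, one training step looks roughly like the following simplified single-device sketch (plain PyTorch, not Megatron's actual code; the model is kept in fp32 for brevity, and the explicit unscale/check loop stands in for _unscale_main_grads_and_check_for_nan described in the next subsection):
import torch

# Tiny stand-in model and data; the loss-scaling mechanics are what matters here.
model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scale = torch.FloatTensor([2.0 ** 16])

x, y = torch.randn(8, 4), torch.randn(8, 1)
loss = torch.nn.functional.mse_loss(model(x), y)

# Scale the loss so that the chain rule scales every gradient as well.
(scale * loss).backward()

# Unscale the gradients and check for inf/nan before updating the parameters.
found_inf = False
for p in model.parameters():
    p.grad.mul_(1.0 / scale)
    found_inf |= (not torch.isfinite(p.grad).all().item())

if not found_inf:
    optimizer.step()        # the update is skipped when any gradient overflowed
optimizer.zero_grad()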
3.2 The unscale step
Unscale, as the name suggests, reverses the scaling. Scaling up exists to keep small gradients from underflowing during the backward pass, but before the parameters are actually updated the gradients must be restored to their true magnitude. The unscale logic is defined in the MixedPrecisionOptimizer class.
class MixedPrecisionOptimizer(MegatronOptimizer):

    def __init__(self, ..., fp16, ..., grad_scaler, ...):
        ...
        # fp16 must be given a grad_scaler; it may scale by 1.0, but it has to exist.
        self.fp16 = fp16
        self.grad_scaler = grad_scaler
        if self.grad_scaler is None:
            assert not self.fp16, 'fp16 expects a grad scaler'
        if self.grad_scaler:
            # Initialize the found_inf flag.
            self.found_inf = torch.FloatTensor([0.0])
        ...
        if self.grad_scaler is None:
            self._scale_one = torch.FloatTensor([1.0])

    # Return the current loss-scale value.
    def get_loss_scale(self):
        if self.grad_scaler is None:
            return self._scale_one
        else:
            return self.grad_scaler.scale
    ...

    def _unscale_main_grads_and_check_for_nan(self):
        # Collect the main (fp32) grads.
        main_grads = self._collect_main_grad_data_for_unscaling()
        # Reset the found_inf flag.
        self.found_inf.fill_(0.0)
        # Unscale the grads and set the found_inf flag.
        torch._amp_foreach_non_finite_check_and_unscale_(
            main_grads, self.found_inf, self.grad_scaler.inv_scale)
        # All-reduce the found_inf flag across the model-parallel group so
        # every rank learns whether any gradient contained an inf or nan.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=self.get_model_parallel_group())
        # Return whether a nan or inf was found.
        return self.found_inf.item() > 0

    # Decorator: autograd tracking is disabled inside step().
    @torch.no_grad()
    def step(self, args, timers):
        ...
        if self.grad_scaler:
            # Unscale grads and check for inf or nan.
            timers('optimizer-unscale-and-check-inf', log_level=1).start(
                barrier=args.barrier_with_L1_time)
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
            timers('optimizer-unscale-and-check-inf').stop()
            # Update the grad scaler.
            self.grad_scaler.update(found_inf_flag)
        ...
For the details of the restore step we should look at the AMP routine torch._amp_foreach_non_finite_check_and_unscale_; since the platform the author focuses on is not CUDA, readers interested in the CUDA path can study the AMP code there on their own.
Its essence is: check each parameter's gradients for inf or nan (raising a flag if any is found) and multiply the gradients by the inverse scale to restore them.
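In pure PyTorch terms, the fused call is roughly equivalent to the following sketch (an illustration of its semantics only, not the actual kernel):
import torch

def non_finite_check_and_unscale(grads, found_inf, inv_scale):
    # Semantics of torch._amp_foreach_non_finite_check_and_unscale_, spelled out:
    # set found_inf to 1.0 if any grad contains inf/nan, and multiply every grad by inv_scale.
    for g in grads:
        if not torch.isfinite(g).all():
            found_inf.fill_(1.0)
        g.mul_(inv_scale)

grads = [torch.tensor([1e4, 2e4]), torch.tensor([float('inf'), 1.0])]
found_inf = torch.FloatTensor([0.0])
non_finite_check_and_unscale(grads, found_inf, inv_scale=1.0 / 2**14)
print(found_inf)   # tensor([1.]) -- an inf was found, so this step will be skipped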
(The rest of this subsection can be skipped.) On the platform the author works on, the call chain of the function above is:
<pytorch>/aten/src/ATen/native/cpu/AmpKernels.cpp
    amp_foreach_non_finite_check_and_unscale_cpu
    |— <pytorch>/third_party/swtensor/master/amp.cpp
        amp_foreach_non_finite_check_and_unscale
        |— third_party/swtensor/slave/amp_slave.cpp
            spawn__local_amp_foreach_non_finite_check_and_unscale
These routines in turn instantiate a template implementation parameterized over the data type; the management-core (MPE) code is as follows:
template <typename FType>
ssize_t amp_foreach_non_finite_check_and_unscale_impl(const Tensor &self,
                                                      size_t num_of_elm,
                                                      float *found_inf_ptr,
                                                      float inv_scale) {
    args_amp_foreach_non_finite_check_and_unscale<FType> args;
    args.input = static_cast<FType *>(self.data());
    // args.found_inf_ptr = found_inf_ptr;
    size_t cg_num = spawn_proxy_n_cg * 64;
    float result[cg_num] = {0};
    args.output = result;
    args.inv_scale = inv_scale;
    args.len = num_of_elm;
    __real_spawn_proxy_run_mpecg(
        (void *)(spawn__local_amp_foreach_non_finite_check_and_unscale<FType>),
        &args, spawn_proxy_n_cg);
    spawn_proxy_join_mpecg(spawn_proxy_n_cg);
    for (int i = 0; i < cg_num; i++) {
        if (result[i] == 1.f) {
            *found_inf_ptr = 1.f;
            break;
        }
    }
    return 0;
}
The args struct is defined as:
template <typename FType>
struct args_amp_foreach_non_finite_check_and_unscale {
    FType *input;
    float inv_scale;
    size_t len;
    float *output;
};

template <typename FType>
void spawn__local_amp_foreach_non_finite_check_and_unscale(
    args_amp_foreach_non_finite_check_and_unscale<FType> *);
It carries the input pointer, the output buffer, the element count and the inverse of the scale.
The slave-core code (third_party/swtensor/slave/amp_slave.cpp) is as follows:
template <typename FType>
void spawn__local_amp_foreach_non_finite_check_and_unscale(
    args_amp_foreach_non_finite_check_and_unscale<FType> *args) {
    using type_t = slave_type_t<FType>;
    init_slave_type<FType>();
    // element number in each cg
    size_t buf_size = LDM_ASIZE / sizeof(type_t);
    // buffer to save the batch element of the Tensor
    type_t buf[buf_size] = {0};
    // cg number
    size_t nthreads = spawn_proxy_n_cg * 64;
    // float *found_inf_ptr = args->found_inf_ptr;
    float *output = static_cast<float *>(args->output);
    float inv_scale = args->inv_scale;
    size_t total_num = args->len;
    size_t left_num = total_num;
    // compute number each iteration
    size_t cpt_num = nthreads * buf_size;
    // compute times: (x + n - 1) / n
    size_t cpt_times = (total_num + cpt_num - 1) / cpt_num;
    type_t *input = reinterpret_cast<type_t *>(args->input);
    ssize_t tid = CRTS_cgn * 64 + CRTS_tid;
    float flag = 0;
    for (size_t k = 0; k < cpt_times; k++) {
        left_num = total_num - k * cpt_num;
        size_t batch_size = buf_size;
        if (left_num < cpt_num) {
            batch_size = (left_num + nthreads - 1) / nthreads;
        }
        size_t left_size = 0;
        if (left_num > tid * batch_size) {
            left_size = batch_size;
            if (left_num <= (tid + 1) * batch_size) {
                left_size = left_num - tid * batch_size;
            }
            CRTS_dma_get(buf,
                         input + k * cpt_num + tid * batch_size,
                         left_size * sizeof(type_t));
            for (size_t j = 0; j < left_size; j++) {
                // if ((*found_inf_ptr != 1.f) &&
                //     (std::isinf(static_cast<float>(buf[j])) ||
                //      std::isnan(static_cast<float>(buf[j]))))
                if ((flag != 1.f) &&
                    (std::isinf(static_cast<float>(buf[j])) ||
                     std::isnan(static_cast<float>(buf[j])))) {
                    // *found_inf_ptr = 1.f;
                    flag = 1.f;
                }
                buf[j] = (inv_scale == 1.f ? buf[j] : buf[j] * inv_scale);
            }
        }
        CRTS_dma_put(input + k * cpt_num + tid * batch_size,
                     buf,
                     left_size * sizeof(type_t));
    }
    CRTS_dma_put(output + tid, &flag, sizeof(float));
}
3.3 State dict
The state_dict mainly makes it easy to print the scaler state and to load it back from checkpoints.
class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
    ...
    def state_dict(self):
        state_dict = {}
        ...
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        ...

    def load_state_dict(self, state_dict):
        ...
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** found the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')