Loss Scale in Megatron-LM

Loss scaling is widely used in mixed-precision training to scale up the parameter gradients computed during back-propagation. This post walks through the complete path in the Megatron-LM framework from configuring the loss scale to applying it, in the hope of deepening understanding.

1. Loss-scale initialization

1.1 Hyperparameter initialization

Megatron supports two ways of setting the loss scale, static and dynamic:

  • Static scale (Constant Loss Scale): the whole run uses one fixed scaling factor, which has to be tuned empirically as a hyperparameter.
  • Dynamic scale (Dynamic Loss Scale): as training progresses, the scaling factor is grown or shrunk at fixed intervals or when certain conditions trigger. In short, the framework searches for the most suitable scale on the fly, which reduces hyperparameter tuning and makes training more automatic.
    See the relevant part of megatron/arguments.py:
group.add_argument('--loss-scale', type=float, default=None,
                   help='Static loss scaling, positive power of 2 '
                   'values can improve fp16 convergence. If None, dynamic '
                   'loss scaling is used.')
group.add_argument('--initial-loss-scale', type=float, default=2**32,
                   help='Initial loss-scale for dynamic loss scaling.')
group.add_argument('--min-loss-scale', type=float, default=1.0,
                   help='Minimum loss scale for dynamic loss scale.')
group.add_argument('--loss-scale-window', type=float, default=1000,
                   help='Window over which to raise/lower dynamic scale.')
group.add_argument('--hysteresis', type=int, default=2,
                   help='hysteresis for dynamic loss scaling')
  • --loss-scale: sets a static loss-scale value, which should be a positive power of 2; if unset, dynamic scaling is used by default.
  • --initial-loss-scale: initial loss-scale value for dynamic scaling, default $2^{32}$.
  • --min-loss-scale: minimum loss-scale value for dynamic scaling, default 1.0.
  • --loss-scale-window: growth interval for dynamic scaling, default 1000 iterations.
  • --hysteresis: backoff threshold for dynamic scaling, i.e. the number of nan/inf occurrences tolerated before backing off, default 2.

Note:

  • If no loss-scale-related argument is set, distributed-optimizer and mixed-precision training default to dynamic loss scaling, starting from an extremely large initial scale and repeatedly backing off to find the most suitable value.
  • Scaling can be effectively disabled in mixed-precision training by setting a static loss scale of 1.0 (useful in special situations); see the example invocations after this list.
  • If --loss-scale is set, the static loss scale takes precedence.
  • --initial-loss-scale: for the dynamic initial value, the author usually sets $32768\ (2^{15})$ or $65536\ (2^{16})$ based on experience.
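
To make the three configurations concrete, the hedged launch sketches below show only the loss-scale-related flags; the launcher, entry point, and all other required arguments are placeholders:

# Dynamic loss scaling (default when --loss-scale is omitted); start below the
# 2**32 default to shorten the initial downward search.
torchrun ... pretrain_gpt.py ... --fp16 \
    --initial-loss-scale 65536 --loss-scale-window 1000 --hysteresis 2

# Static loss scaling with a fixed factor of 1024.
torchrun ... pretrain_gpt.py ... --fp16 --loss-scale 1024

# Effectively no scaling (special cases only): a constant factor of 1.0.
torchrun ... pretrain_gpt.py ... --fp16 --loss-scale 1.0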

1.2 Initialization in the classes

The arguments above are consumed in megatron/optimizer/__init__.py:

def get_megatron_optimizer(...):
    args = get_args()
    ...
    if args.fp16 or args.bf16 or args.use_distributed_optimizer:
        grad_scaler = None
        # Constant loss scale.
        if args.loss_scale:
            grad_scaler = ConstantGradScaler(args.loss_scale)
        # Dynamic loss scale.
        else:
            if args.fp16:
                grad_scaler = DynamicGradScaler(
                    initial_scale=args.initial_loss_scale,
                    min_scale=args.min_loss_scale,
                    growth_factor=2.0,
                    backoff_factor=0.5,
                    growth_interval=args.loss_scale_window,
                    hysteresis=args.hysteresis)

        # Megatron optimizer.
        opt_ty = DistributedOptimizer \
            if args.use_distributed_optimizer else \
            Float16OptimizerWithFloat16Params
        return opt_ty(optimizer,
                      args.clip_grad, ..., grad_scaler, model)

The loss scale is only used at all when distributed_optimizer or fp16/bf16 (mixed precision) is enabled:

  • First check whether a static scale was requested (usable by both mixed precision and the distributed optimizer).
  • Otherwise fall back to dynamic scaling, which is only created for fp16 training; note that bf16 without --loss-scale therefore keeps grad_scaler = None, i.e. no scaling at all. A compact restatement of this decision logic follows this list.
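
As a restatement under stated assumptions, the hypothetical helper below (select_grad_scaler is not a Megatron function) mirrors the branch in get_megatron_optimizer and shows what each flag combination yields:

def select_grad_scaler(fp16, bf16, use_distributed_optimizer, loss_scale=None):
    """Hypothetical helper mirroring the branch above (illustration only)."""
    if not (fp16 or bf16 or use_distributed_optimizer):
        return None                      # plain fp32 training: no scaler at all
    if loss_scale:                       # --loss-scale given -> static scaling
        return f'ConstantGradScaler({loss_scale})'
    if fp16:                             # dynamic scaling is created only for fp16
        return 'DynamicGradScaler(initial_scale=2**32, ...)'
    return None                          # e.g. bf16 without --loss-scale: unscaled

print(select_grad_scaler(fp16=True,  bf16=False, use_distributed_optimizer=False))
print(select_grad_scaler(fp16=False, bf16=True,  use_distributed_optimizer=True))
print(select_grad_scaler(fp16=False, bf16=True,  use_distributed_optimizer=False,
                         loss_scale=1024.0))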

2. Managing the scale value

The loss-scale arguments above are used to initialize grad_scaler, which is then handed to the optimizer. The GradScaler implementations live in megatron/optimizer/grad_scaler.py:

2.1 MegatronGradScaler (Abstract)

class MegatronGradScaler(ABC):

    def __init__(self, initial_scale):
        """Initialize scale value with the input initial scale."""
        assert initial_scale > 0.0
        # # modified for cpu use
        #self._scale = torch.cuda.FloatTensor([initial_scale])
        self._scale = torch.FloatTensor([initial_scale])

    @property
    def scale(self):
        return self._scale

    @property
    def inv_scale(self):
        return self._scale.double().reciprocal().float()

    @abstractmethod
    def update(self, found_inf):
        pass

    @abstractmethod
    def state_dict(self):
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        pass

MegatronGradScaler above is an abstract base class that defines the basic interface and some initialization.

_scale is a protected member (by convention, only accessed and modified through class methods) that holds the current scale value at runtime. The scale and inv_scale properties return the scale and its reciprocal, used respectively to scale the loss and to unscale (restore) the gradients.

The three abstract methods update, state_dict, and load_state_dict exist mainly to accommodate dynamic scaling.

2.2 Constant Grad Scaler

class ConstantGradScaler(MegatronGradScaler):

    def update(self, found_inf):
        pass

    def state_dict(self):
        return dict()

    def load_state_dict(self, state_dict):
        pass

The static scaler is trivial: the scale never needs updating, and _scale is set once in the constructor (inherited from the base class) from the --loss-scale argument.
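
A minimal, self-contained sketch (plain PyTorch, not Megatron code) of what these two properties are for: scaling the loss before backward() scales every gradient by the same factor, and multiplying by inv_scale restores them:

import torch

# Stand-in for ConstantGradScaler(1024.0): just the two quantities its base
# class exposes as .scale and .inv_scale.
scale = torch.FloatTensor([1024.0])
inv_scale = scale.double().reciprocal().float()   # reciprocal computed in fp64

w = torch.tensor(3.0, requires_grad=True)
loss = w * w                                      # toy "forward pass"
(scale * loss).sum().backward()                   # backward on the scaled loss

print(w.grad)                     # tensor(6144.)  -> 2*w scaled by 1024
print(w.grad * inv_scale)         # tensor([6.])   -> the unscaled gradient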

2.3 Dynamic Grad Scaler

Next we look at the dynamic scaler, going through __init__, state_dict/load_state_dict, and update in turn.

class DynamicGradScaler(MegatronGradScaler):

    def __init__(self, initial_scale, min_scale,
                 growth_factor, backoff_factor,
                 growth_interval, hysteresis):
        """"Grad scaler with dynamic scale that gets adjusted
        during training."""
        super(DynamicGradScaler, self).__init__(initial_scale)

        # Lower bound on the scale.
        assert min_scale > 0.0
        assert min_scale <= initial_scale
        # # modified for cpu use
        #self.min_scale = torch.cuda.FloatTensor([min_scale])
        self.min_scale = torch.FloatTensor([min_scale])
        # Growth and backoff factors for the scale.
        assert growth_factor > 1.0
        #self.growth_factor = torch.cuda.FloatTensor([growth_factor])
        self.growth_factor = torch.FloatTensor([growth_factor])
        assert backoff_factor < 1.0
        assert backoff_factor > 0.0
        #self.backoff_factor = torch.cuda.FloatTensor([backoff_factor])
        self.backoff_factor = torch.FloatTensor([backoff_factor])
        # Interval over which if we don't see any inf/nan,
        # we will scale the grad scale by the growth factor.
        assert growth_interval > 0
        self.growth_interval = growth_interval
        # Number of inf/nans we should see before scaling down
        # the grad scale by the backoff factor.
        assert hysteresis > 0
        self.hysteresis = hysteresis

        # Trackers.
        self._growth_tracker = 0
        self._hysteresis_tracker = self.hysteresis

The initialization itself needs little analysis: a series of asserts and bookkeeping. The member variables are worth spelling out:

  • Set from user arguments:
    • initial_scale: initial scale value, range $(0, +\infty)$, set via --initial-loss-scale, default $2^{32}$.
    • min_scale: minimum scale value, range $(0, \text{initial\_scale}]$, set via --min-loss-scale, default 1.0.
    • growth_interval: growth interval for the scale, range $(0, +\infty)$, set via --loss-scale-window, default 1000 iterations.
    • hysteresis: backoff threshold for the scale, a positive integer, set via --hysteresis, default 2.
  • Hard-coded (see 1.2):
    • growth_factor: growth factor for the scale, range $(1.0, +\infty)$, hard-coded to 2.0.
    • backoff_factor: backoff factor for the scale, range $(0.0, 1.0)$, hard-coded to 0.5.
  • Protected members of this class:
    • _growth_tracker: growth counter, initialized to 0.
    • _hysteresis_tracker: backoff counter, initialized to self.hysteresis.

The state_dict part:

def state_dict(self):
	state_dict = {}
	state_dict['scale'] = self._scale
	state_dict['growth_tracker'] = self._growth_tracker
	state_dict['hysteresis_tracker'] = self._hysteresis_tracker
	return state_dict


def load_state_dict(self, state_dict):
	# # modified for cpu use
	#self._scale = state_dict['scale'].cuda(torch.cuda.current_device())
	self._scale = state_dict['scale']
	self._growth_tracker = state_dict['growth_tracker']
	self._hysteresis_tracker = state_dict['hysteresis_tracker']

The state_dict additionally saves the current _scale value, the growth counter (_growth_tracker), and the backoff counter (_hysteresis_tracker).
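
As a usage note, these three entries are what lets the scaler state survive a checkpoint. A hedged sketch, assuming megatron.optimizer.grad_scaler is importable on your install (in real training this happens inside the optimizer's own state_dict, see 3.3):

import torch
from megatron.optimizer.grad_scaler import DynamicGradScaler  # assumed import path

kwargs = dict(initial_scale=2**16, min_scale=1.0, growth_factor=2.0,
              backoff_factor=0.5, growth_interval=1000, hysteresis=2)
scaler = DynamicGradScaler(**kwargs)
scaler.update(True)                                 # consume one hysteresis tick

torch.save(scaler.state_dict(), 'grad_scaler.pt')   # hypothetical checkpoint file

restored = DynamicGradScaler(**kwargs)
restored.load_state_dict(torch.load('grad_scaler.pt'))
print(restored.scale)                               # scale and trackers carried over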

The dynamic update (you could also call it search, or adaptation) of the scale value:

def update(self, found_inf):

	# If we have an inf/nan, growth tracker is set to 0
	# and hysterisis tracker is reduced by 1.
	if found_inf:
		self._growth_tracker = 0
		self._hysteresis_tracker -= 1
		# Now if we are out of hysteresis count, scale down the loss.
		if self._hysteresis_tracker <= 0:
			self._scale = torch.max(self._scale * self.backoff_factor,
									self.min_scale)
	else:
		# If there is no nan/inf, increment the growth tracker.
		self._growth_tracker += 1
		# If we have had enough consequitive intervals with no nan/inf:
		if self._growth_tracker == self.growth_interval:
			# Reset the tracker and hysteresis trackers,
			self._growth_tracker = 0
			self._hysteresis_tracker = self.hysteresis
			# and scale up the loss scale.
			self._scale = self._scale * self.growth_factor

The adaptation of _scale has two directions, backoff and growth:

  • Backoff: if an inf or nan is found in the gradients, the growth counter _growth_tracker is reset to zero and _hysteresis_tracker is decremented; once _hysteresis_tracker reaches 0 or below, _scale itself is backed off (halved). The comparison is "less than or equal" because after the first backoff is triggered (the counter is "low on health"), every further inf/nan causes another backoff, until a growth eventually refills the counter.
  • Growth: while no inf/nan is seen, _growth_tracker is incremented; when it reaches the growth interval, _hysteresis_tracker is reset (back to "full health") and _scale is grown (doubled).

In short:

  • After the last adjustment of the scale, once growth_interval (1000) consecutive iterations pass without any nan/inf, the scale is doubled (x2) and both counters are reset immediately.
  • After the last growth of the scale, once nan/inf has occurred hysteresis (2) times (not necessarily consecutively), the scale is halved (x0.5); after that first backoff, every further nan/inf halves the scale again. See the simulation sketch below.
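
To make the trajectory concrete, here is a self-contained simulation of the same update rule (a re-implementation for illustration, not Megatron's class), fed with a synthetic sequence of found_inf flags:

# Minimal re-implementation of DynamicGradScaler.update() for illustration only.
class ToyDynamicScaler:
    def __init__(self, initial_scale=2**16, min_scale=1.0,
                 growth_factor=2.0, backoff_factor=0.5,
                 growth_interval=4, hysteresis=2):   # short interval for the demo
        self.scale = float(initial_scale)
        self.min_scale = min_scale
        self.growth_factor = growth_factor
        self.backoff_factor = backoff_factor
        self.growth_interval = growth_interval
        self.hysteresis = hysteresis
        self._growth_tracker = 0
        self._hysteresis_tracker = hysteresis

    def update(self, found_inf):
        if found_inf:
            self._growth_tracker = 0
            self._hysteresis_tracker -= 1
            if self._hysteresis_tracker <= 0:
                self.scale = max(self.scale * self.backoff_factor, self.min_scale)
        else:
            self._growth_tracker += 1
            if self._growth_tracker == self.growth_interval:
                self._growth_tracker = 0
                self._hysteresis_tracker = self.hysteresis
                self.scale = self.scale * self.growth_factor

s = ToyDynamicScaler()
for found_inf in [False, True, True, True, False, False, False, False]:
    s.update(found_inf)
    print(found_inf, s.scale)
# Two infs exhaust hysteresis=2 and halve the scale; the third inf halves it
# again; four clean iterations then double it and reset both trackers.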

3. Applying the scale

This section looks further into how grad_scaler does its work inside the optimizer. The code is in megatron/optimizer/optimizer.py; we focus on the FP16 mixed-precision optimizer.

Class hierarchy:

ABC
 |— MegatronOptimizer
     |— FP32Optimizer
     |— MixedPrecisionOptimizer
         |— Float16OptimizerWithFloat16Params

3.1 Scaling the loss

MegatronOptimizer implements the simplest part of the scaling: once the forward pass finishes and the loss is available, multiply it by the currently maintained _scale value; the chain rule of back-propagation then propagates that factor into every parameter gradient automatically.

class MegatronOptimizer(ABC):
    ...
    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass

    # Scale the loss.
    def scale_loss(self, loss):
        """Simple scaling."""
        return self.get_loss_scale() * loss

In train_step() in megatron/training.py, the optimizer's scale_loss method is passed in:

pretrain → train → train_step → get_forward_backward_func

def train_step(forward_step_func, ..., optimizer, ...):
    forward_backward_func = get_forward_backward_func()
    ...
    losses_reduced = forward_backward_func(
        forward_step_func=forward_step_func,
        ...
        grad_scaler=optimizer.scale_loss,
        ...)
    ...

Following the chain further down, in megatron/core/pipeline_parallel/schedules.py the grad_scaler argument is ultimately consumed by backward_step():

forward_backward_func (from get_forward_backward_func) → forward_backward_no_pipelining → backward_step

def backward_step(grad_scaler, ..., output_tensor, ...):
    ...
    # Backward pass.
    if output_tensor_grad[0] is None and grad_scaler is not None:
        output_tensor = grad_scaler(output_tensor[0])

Looking at forward_step, it is easy to see that output_tensor[0] here is exactly the loss computed by the forward pass. This matches the usual picture of loss scaling: the backward pass is run on the loss multiplied by the scale.
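
As a quick reminder of why this matters in fp16 (plain PyTorch; the printed values are approximate): gradient values below fp16's smallest subnormal (~6e-8) flush to zero, while the same value multiplied by the loss scale survives:

import torch

tiny_grad = torch.tensor(1e-8)          # smaller than fp16's smallest subnormal (~6e-8)
print(tiny_grad.half())                 # tensor(0., dtype=torch.float16) -> underflows, lost
scaled = (tiny_grad * 1024).half()      # scale by 2**10 before converting to fp16
print(scaled)                           # ~1.02e-05 in fp16 -> preserved
print(scaled.float() / 1024)            # unscale in fp32 (as the main grads are) -> ~1e-8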

3.2 The unscale operation

Unscale, as the name suggests, undoes the scaling. We scale the loss up to keep gradients from underflowing to zero during the backward pass, but before the final parameter update the gradients have to be restored. That restoration is defined in the MixedPrecisionOptimizer class.

class MixedPrecisionOptimizer(MegatronOptimizer):
    def __init__(..., fp16, ..., grad_scaler, ...):
        ...
        # Formally, fp16 must have a grad_scaler; it may be a no-op
        # (constant scale 1.0), but it must exist.
        self.fp16 = fp16
        self.grad_scaler = grad_scaler
        if self.grad_scaler is None:
            assert not self.fp16, 'fp16 expects a grad scaler'
        if self.grad_scaler:
            self.found_inf = torch.FloatTensor([0.0])  # initialize the found_inf flag
        ...
        if self.grad_scaler is None:
            self._scale_one = torch.FloatTensor([1.0])

    # Return the current loss scale.
    def get_loss_scale(self):
        if self.grad_scaler is None:
            return self._scale_one
        else:
            return self.grad_scaler.scale
    ...

    def _unscale_main_grads_and_check_for_nan(self):
        # Collect the main grads.
        main_grads = self._collect_main_grad_data_for_unscaling()
        # Reset the found_inf flag.
        self.found_inf.fill_(0.0)
        # Unscale the grads and set the found_inf flag.
        torch._amp_foreach_non_finite_check_and_unscale_(
            main_grads, self.found_inf, self.grad_scaler.inv_scale)
        # All-reduce found_inf across the model-parallel group to see
        # whether any rank hit an inf/nan.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=self.get_model_parallel_group())
        # Return whether an inf/nan was found.
        return (self.found_inf.item() > 0)

    # Decorator: autograd tracking is disabled inside step().
    @torch.no_grad()
    def step(self, args, timers):
        ...
        if self.grad_scaler:
            # Unscale the grads and check for inf/nan.
            timers('optimizer-unscale-and-check-inf', log_level=1).start(
                barrier=args.barrier_with_L1_time)
            found_inf_flag = self._unscale_main_grads_and_check_for_nan()
            timers('optimizer-unscale-and-check-inf', log_level=1).stop()
            # Update the grad scaler.
            self.grad_scaler.update(found_inf_flag)
            ...

For the unscale itself we have to look at the AMP function. Since the platform the author focuses on is not CUDA, readers can study the CUDA AMP implementation on their own.
Its essence: check each parameter's gradients for inf/nan (recording the result in found_inf) and multiply the gradient values by inv_scale to restore them.
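
A small sketch of that behavior through the same private PyTorch entry point the optimizer calls above; treat its availability on your backend and PyTorch version, and the exact printed values, as assumptions:

import torch

grads = [torch.tensor([1024.0, 2048.0]),
         torch.tensor([float('inf'), 4096.0])]
found_inf = torch.tensor([0.0])
inv_scale = torch.tensor([1.0 / 1024.0])

# In-place: multiplies every grad by inv_scale and sets found_inf to 1.0 if any
# element is inf/nan (private API, signature may change across versions).
torch._amp_foreach_non_finite_check_and_unscale_(grads, found_inf, inv_scale)

print(found_inf)   # tensor([1.]) because of the inf in the second tensor
print(grads[0])    # tensor([1., 2.]) -> unscaled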

(The rest of this subsection can be skipped.) On the platform the author works on, the call chain for the function above is:

<pytorch>/aten/src/ATen/native/cpu/AmpKernels.cpp
    amp_foreach_non_finite_check_and_unscale_cpu
    |— <pytorch>/third_party/swtensor/master/amp.cpp
        amp_foreach_non_finite_check_and_unscale
        |— third_party/swtensor/slave/amp_slave.cpp
            spawn__local_amp_foreach_non_finite_check_and_unscale

Each of these in turn instantiates a template specialized for the data type. The host-core (MPE) code:

template <typename FType>
ssize_t amp_foreach_non_finite_check_and_unscale_impl(
    const Tensor& self, size_t num_of_elm, float* found_inf_ptr, float inv_scale) {
    args_amp_foreach_non_finite_check_and_unscale<FType> args;
    args.input = static_cast<FType*>(self.data());
    // args.found_inf_ptr = found_inf_ptr;
    size_t cg_num = spawn_proxy_n_cg * 64;
    float result[cg_num] = {0};
    args.output = result;
    args.inv_scale = inv_scale;
    args.len = num_of_elm;
    __real_spawn_proxy_run_mpecg(
        (void*)(spawn__local_amp_foreach_non_finite_check_and_unscale<FType>),
        &args, spawn_proxy_n_cg);
    spawn_proxy_join_mpecg(spawn_proxy_n_cg);
    for (int i = 0; i < cg_num; i++) {
        if (result[i] == 1.f) {
            *found_inf_ptr = 1.f;
            break;
        }
    }
    return 0;
}

The args struct is defined as follows:

template <typename FType>
struct args_amp_foreach_non_finite_check_and_unscale {
    FType* input;
    float inv_scale;
    size_t len;
    float* output;
};

template <typename FType>
void spawn__local_amp_foreach_non_finite_check_and_unscale(
    args_amp_foreach_non_finite_check_and_unscale<FType>*);

It holds the input and output pointers, the element count, and the inverse scale.

The slave-core code (third_party/swtensor/slave/amp_slave.cpp):

template <typename FType>
void spawn__local_amp_foreach_non_finite_check_and_unscale(
    args_amp_foreach_non_finite_check_and_unscale<FType>* args) {
    using type_t = slave_type_t<FType>;
    init_slave_type<FType>();
    // element number in each cg
    size_t buf_size = LDM_ASIZE / sizeof(type_t);
    // buffer to save the batch elements of the Tensor
    type_t buf[buf_size] = {0};
    // cg number
    size_t nthreads = spawn_proxy_n_cg * 64;
    // float* found_inf_ptr = args->found_inf_ptr;
    float* output = static_cast<float*>(args->output);
    float inv_scale = args->inv_scale;
    size_t total_num = args->len;
    size_t left_num = total_num;
    // compute number each iteration
    size_t cpt_num = nthreads * buf_size;
    // compute times: (x + n - 1) / n
    size_t cpt_times = (total_num + cpt_num - 1) / cpt_num;
    type_t* input = reinterpret_cast<type_t*>(args->input);
    ssize_t tid = CRTS_cgn * 64 + CRTS_tid;
    float flag = 0;
    for (size_t k = 0; k < cpt_times; k++) {
        left_num = total_num - k * cpt_num;
        size_t batch_size = buf_size;
        if (left_num < cpt_num) {
            batch_size = (left_num + nthreads - 1) / nthreads;
        }
        size_t left_size = 0;
        if (left_num > tid * batch_size) {
            left_size = batch_size;
            if (left_num <= (tid + 1) * batch_size) {
                left_size = left_num - tid * batch_size;
            }
            CRTS_dma_get(buf,
                         input + k * cpt_num + tid * batch_size,
                         left_size * sizeof(type_t));
            for (size_t j = 0; j < left_size; j++) {
                // if ((*found_inf_ptr != 1.f) && (isinf(buf[j]) || isnan(buf[j])))
                if ((flag != 1.f) &&
                    (std::isinf(static_cast<float>(buf[j])) ||
                     std::isnan(static_cast<float>(buf[j])))) {
                    // *found_inf_ptr = 1.f;
                    flag = 1.f;
                }
                buf[j] = (inv_scale == 1.f ? buf[j] : buf[j] * inv_scale);
            }
            CRTS_dma_put(input + k * cpt_num + tid * batch_size,
                         buf,
                         left_size * sizeof(type_t));
        }
    }
    CRTS_dma_put(output + tid, &flag, sizeof(float));
}

3.3 State dict

The state_dict is mainly there for logging and checkpoint loading:

class Float16OptimizerWithFloat16Params(MixedPrecisionOptimizer):
    ...
    def state_dict(self):
        state_dict = {}
        ...
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        ...

    def load_state_dict(self, state_dict):
        ...
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** found the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')
