请选择 进入手机版 | 继续访问电脑版
查看: 1438|回复: 1

Pytorch的BatchNorm层使用中容易出现的问题

[复制链接]

665

主题

1234

帖子

6561

积分

xdtech

Rank: 5Rank: 5

积分
6561
发表于 2020-5-14 08:53:23 | 显示全部楼层 |阅读模式
Batch Normalization,批规范化
Batch Normalization(简称为BN)[2],中文翻译成批规范化,是在深度学习中普遍使用的一种技术,通常用于解决多层神经网络中间层的协方差偏移(Internal Covariate Shift)问题,类似于网络输入进行零均值化和方差归一化的操作,不过是在中间层的输入中操作而已,具体原理不累述了,见[2-4]的描述即可。

在BN操作中,最重要的无非是这四个式子:
Input:Output:更新过程:μBσ2BxˆiyiB={x1,⋯,xm},为m个样本组成的一个batch数据。需要学习到的是γ和β,在框架中一般表述成weight和bias。←1m∑i=1mxi    //得到batch中的统计特性之一:均值←1m∑i=1m(xi−μB)2    //得到batch中的另一个统计特性:方差←xi−μBσ2B+ϵ−−−−−−√    //规范化,其中ϵ是一个很小的数,防止计算出现数值问题。←γxˆi+β≡BNγ,β(xi)    //这一步是输出尺寸伸缩和偏移。\begin{aligned}\mathbf{Input}: & \mathcal{B}=\{x_1,\cdots,x_m\},为m个样本组成的一个batch数据 。\\\mathbf{Output}: & 需要学习到的是 \gamma和\beta,在框架中一般表述成\mathrm{weight}和\mathrm{bias}。\\更新过程: & \\ \mu_{\mathcal{B}} & \leftarrow \frac{1}{m} \sum_{i=1}^m x_i \ \ \ \ // 得到batch中的统计特性之一:均值 \\\sigma_{\mathcal{B}}^2 &\leftarrow \frac{1}{m} \sum_{i=1}^m (x_i - \mu_{\mathcal{B}})^2 \ \ \ \ // 得到batch中的另一个统计特性:方差 \\\hat{x}_i & \leftarrow \dfrac{x_i-\mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2+\epsilon}} \ \ \ \ \\&// 规范化,其中\epsilon是一个很小的数,防止计算出现数值问题。\\y_i &\leftarrow \gamma \hat{x}_i+\beta \equiv \mathrm{BN}_{\gamma, \beta}(x_i) \ \ \ \ //这一步是输出尺寸伸缩和偏移。\end{aligned}
Input:
Output:
更新过程:
μ
B
​       

σ
B
2
​       

x
^

i
​       

y
i
​       

​       

B={x
1
​       
,⋯,x
m
​       
},为m个样本组成的一个batch数据。
需要学习到的是γ和β,在框架中一般表述成weight和bias。

m
1
​       

i=1

m
​       
x
i
​       
     //得到batch中的统计特性之一:均值

m
1
​       

i=1

m
​       
(x
i
​       
−μ
B
​       
)
2
     //得到batch中的另一个统计特性:方差

σ
B
2
​       

​       

x
i
​       
−μ
B
​       

​       

//规范化,其中ϵ是一个很小的数,防止计算出现数值问题。
←γ
x
^

i
​       
+β≡BN
γ,β
​       
(x
i
​       
)    //这一步是输出尺寸伸缩和偏移。
​       


注意到这里的最后一步也称之为仿射(affine),引入这一步的目的主要是设计一个通道,使得输出output至少能够回到输入input的状态(当γ=1,β=0\gamma=1,\beta=0γ=1,β=0时)使得BN的引入至少不至于降低模型的表现,这是深度网络设计的一个套路。
整个过程见流程图,BN在输入后插入,BN的输出作为规范后的结果输入的后层网络中。

forward
backward
forward
backward
input batch
Batch_Norm
Output batch
好了,这里我们记住了,在BN中,一共有这四个参数我们要考虑的:

γ,β\gamma, \betaγ,β:分别是仿射中的weight\mathrm{weight}weight和bias\mathrm{bias}bias,在pytorch中用weight和bias表示。
μB\mu_{\mathcal{B}}μ
B
​       
和σ2B\sigma_{\mathcal{B}}^2σ
B
2
​       
:和上面的参数不同,这两个是根据输入的batch的统计特性计算的,严格来说不算是“学习”到的参数,不过对于整个计算是很重要的。在pytorch中,这两个统计参数,用running_mean和running_var表示[5],这里的running指的就是当前的统计参数不一定只是由当前输入的batch决定,还可能和历史输入的batch有关,详情见以下的讨论,特别是参数momentum那部分。
Update 2020/3/16:
因为BN层的考核,在工作面试中实在是太常见了,在本文顺带补充下BN层的参数的具体shape大小。
以图片输入作为例子,在pytorch中即是nn.BatchNorm2d(),我们实际中的BN层一般是对于通道进行的,举个例子而言,我们现在的输入特征(可以视为之前讨论的batch中的其中一个样本的shape)为x∈RC×W×H\mathbf{x} \in \mathbb{R}^{C \times W \times H}x∈R
C×W×H
(其中C是通道数,W是width,H是height),那么我们的μB∈RC\mu_{\mathcal{B}} \in \mathbb{R}^{C}μ
B
​       
∈R
C
,而方差σ2B∈RC\sigma^{2}_{\mathcal{B}} \in \mathbb{R}^Cσ
B
2
​       
∈R
C
。而仿射中weight,γ∈RC\mathrm{weight}, \gamma \in \mathbb{R}^{C}weight,γ∈R
C
以及bias,β∈RC\mathrm{bias}, \beta \in \mathbb{R}^{C}bias,β∈R
C
。我们会发现,这些参数,无论是学习参数还是统计参数都会通道数有关,其实在pytorch中,通道数的另一个称呼是num_features,也即是特征数量,因为不同通道的特征信息通常很不相同,因此需要隔离开通道进行处理。

有些朋友可能会认为这里的weight应该是一个张量,而不应该是一个矢量,其实不是的,这里的weight其实应该看成是 对输入特征图的每个通道得到的归一化后的xˆ\hat{\mathbf{x}}
x
^
进行尺度放缩的结果,因此对于一个通道数为CCC的输入特征图,那么每个通道都需要一个尺度放缩因子,同理,bias也是对于每个通道而言的。这里切勿认为 yi←γxˆi+βy_i \leftarrow \gamma \hat{x}_i+\betay
i
​       
←γ
x
^

i
​       
+β这一步是一个全连接层,他其实只是一个尺度放缩而已。关于这些参数的形状,其实可以直接从pytorch源代码看出,这里截取了_NormBase层的部分初始代码,便可一见端倪。

class _NormBase(Module):
    """Common base of _InstanceNorm and _BatchNorm"""
    _version = 2
    __constants__ = ['track_running_stats', 'momentum', 'eps',
                     'num_features', 'affine']

    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True,
                 track_running_stats=True):
        super(_NormBase, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.momentum = momentum
        self.affine = affine
        self.track_running_stats = track_running_stats
        if self.affine:
            self.weight = Parameter(torch.Tensor(num_features))
            self.bias = Parameter(torch.Tensor(num_features))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        if self.track_running_stats:
            self.register_buffer('running_mean', torch.zeros(num_features))
            self.register_buffer('running_var', torch.ones(num_features))
            self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))
        else:
            self.register_parameter('running_mean', None)
            self.register_parameter('running_var', None)
            self.register_parameter('num_batches_tracked', None)
        self.reset_parameters()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
在Pytorch中使用
Pytorch中的BatchNorm的API主要有:

torch.nn.BatchNorm1d(num_features,
                     eps=1e-05,
                     momentum=0.1,
                     affine=True,
                     track_running_stats=True)
1
2
3
4
5
一般来说pytorch中的模型都是继承nn.Module类的,都有一个属性trainning指定是否是训练状态,训练状态与否将会影响到某些层的参数是否是固定的,比如BN层或者Dropout层。通常用model.train()指定当前模型model为训练状态,model.eval()指定当前模型为测试状态。
同时,BN的API中有几个参数需要比较关心的,一个是affine指定是否需要仿射,还有个是track_running_stats指定是否跟踪当前batch的统计特性。容易出现问题也正好是这三个参数:trainning,affine,track_running_stats。

其中的affine指定是否需要仿射,也就是是否需要上面算式的第四个,如果affine=False,则γ=1,β=0\gamma=1,\beta=0γ=1,β=0,并且不能学习被更新。一般都会设置成affine=True[10]
trainning和track_running_stats,track_running_stats=True表示跟踪整个训练过程中的batch的统计特性,得到方差和均值,而不只是仅仅依赖与当前输入的batch的统计特性。相反的,如果track_running_stats=False那么就只是计算当前输入的batch的统计特性中的均值和方差了。当在推理阶段的时候,如果track_running_stats=False,此时如果batch_size比较小,那么其统计特性就会和全局统计特性有着较大偏差,可能导致糟糕的效果。
一般来说,trainning和track_running_stats有四种组合[7]

trainning=True, track_running_stats=True。这个是期望中的训练阶段的设置,此时BN将会跟踪整个训练过程中batch的统计特性。
trainning=True, track_running_stats=False。此时BN只会计算当前输入的训练batch的统计特性,可能没法很好地描述全局的数据统计特性。
trainning=False, track_running_stats=True。这个是期望中的测试阶段的设置,此时BN会用之前训练好的模型中的(假设已经保存下了)running_mean和running_var并且不会对其进行更新。一般来说,只需要设置model.eval()其中model中含有BN层,即可实现这个功能。[6,8]
trainning=False, track_running_stats=False 效果同(2),只不过是位于测试状态,这个一般不采用,这个只是用测试输入的batch的统计特性,容易造成统计特性的偏移,导致糟糕效果。
同时,我们要注意到,BN层中的running_mean和running_var的更新是在forward()操作中进行的,而不是optimizer.step()中进行的,因此如果处于训练状态,就算你不进行手动step(),BN的统计特性也会变化的。如

model.train() # 处于训练状态

for data, label in self.dataloader:
        pred = model(data)  
        # 在这里就会更新model中的BN的统计特性参数,running_mean, running_var
        loss = self.loss(pred, label)
        # 就算不要下列三行代码,BN的统计特性参数也会变化
        opt.zero_grad()
        loss.backward()
        opt.step()
1
2
3
4
5
6
7
8
9
10
这个时候要将model.eval()转到测试阶段,才能固定住running_mean和running_var。有时候如果是先预训练模型然后加载模型,重新跑测试的时候结果不同,有一点性能上的损失,这个时候十有八九是trainning和track_running_stats设置的不对,这里需要多注意。 [8]

假设一个场景,如下图所示:

input
model_A
model_B
output
此时为了收敛容易控制,先预训练好模型model_A,并且model_A内含有若干BN层,后续需要将model_A作为一个inference推理模型和model_B联合训练,此时就希望model_A中的BN的统计特性值running_mean和running_var不会乱变化,因此就必须将model_A.eval()设置到测试模式,否则在trainning模式下,就算是不去更新该模型的参数,其BN都会改变的,这个将会导致和预期不同的结果。

Update 2020/3/17:
评论区的Oshrin朋友提出问题

作者您好,写的很好,但是是否存在问题。即使将track_running_stats设置为False,如果momentum不为None的话,还是会用滑动平均来计算running_mean和running_var的,而非是仅仅使用本batch的数据情况。而且关于冻结bn层,有一些更好的方法。

这里的momentum的作用,按照文档,这个参数是在对统计参数进行更新过程中,进行指数平滑使用的,比如统计参数的更新策略将会变成:
xˆnew=(1−momentum)×xˆ+momentum×xt\hat{x}_{\mathrm{new}} = (1-\mathrm{momentum}) \times \hat{x} + \mathrm{momentum} \times x_t
x
^

new
​       
=(1−momentum)×
x
^
+momentum×x
t
​       


其中的更新后的统计参数xˆnew\hat{x}_{\mathrm{new}}
x
^

new
​       
,是根据当前观察xtx_tx
t
​       
和历史观察xˆ\hat{x}
x
^
进行加权平均得到的(差分的加权平均相当于历史序列的指数平滑),默认的momentum=0.1。然而跟踪历史信息并且更新的这个行为是基于track_running_stats为true并且training=true的情况同时成立的时候,才会进行的,当在track_running_stats=true, training=false时(在默认的model.eval()情况下,即是之前谈到的四种组合的第三个,既满足这种情况),将不涉及到统计参数的指数滑动更新了。[12,13]

这里引用一个不错的BN层冻结的例子,如:[14]

import torch
import torch.nn as nn
from torch.nn import init
from torchvision import models
from torch.autograd import Variable
from apex.fp16_utils import *

def fix_bn(m):
    classname = m.__class__.__name__
    if classname.find('BatchNorm') != -1:
        m.eval()

model = models.resnet50(pretrained=True)
model.cuda()
model = network(model)
model.train()
model.apply(fix_bn) # fix batchnorm
input = Variable(torch.FloatTensor(8, 3, 224, 224).cuda())
output = model(input)
output_mean = torch.mean(output)
output_mean.backward()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
总结来说,在某些情况下,即便整体的模型处于model.train()的状态,但是某些BN层也可能需要按照需求设置为model_bn.eval()的状态。


回复

使用道具 举报

665

主题

1234

帖子

6561

积分

xdtech

Rank: 5Rank: 5

积分
6561
 楼主| 发表于 2020-5-14 08:53:32 | 显示全部楼层
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

快速回复 返回顶部 返回列表