资讯详情

深度学习与计算机视觉教程(3) | 损失函数与最优化(CV通关指南·完结)

ShowMeAI研究中心

  • 作者:韩信子@ShowMeAI
  • 教程地址:http://www.showmeai.tech/tutorials/37
  • 本文地址:http://www.showmeai.tech/article-detail/262
  • 声明:所有版权,请联系平台和作者,注明来源

本系列为 计算机视觉深度学习(Deep Learning for Computer Vision)》全套学习笔记,相应的课程视频可以在 查看。获取更多信息的方法见文末。


引言

在上一篇 在内容中,我们介绍了线性分类器。我们希望线性分类器能够准确地对图像进行分类,并有一套优化其权重参数的方法。这就是本文ShowMeAI介绍损失函数和最优化相关知识。

本篇重点

  • 损失函数
  • 数据损失和正则损失
  • SVM 损失
  • Softmax损失
  • 优化策略
  • 梯度计算方法
  • 梯度下降

1.线性分类:损失函数

1.1 损失函数的概念

回到之前解释过的小猫分类示例,这个例子中的权重值 W W W 很差,因为猫的分数很低(-96.8),而狗(437.9)和船(61.95)比较高。

我们定义(Loss Function)(有时也叫) L L L 衡量预测结果「不满意程度」。评分函数输出结果与真实结果的差异越大,损失函数越大,反之亦然。

对于有 N N N 训练样本对应 N N N 训练集数据的标签 ( x i , y i ) (x_{i},y_{i}) (xi,yi)),损失函数定义为:

L = 1 N ∑ i = 1 N L i ( f ( x i , W ) , y i ) L=\frac{1}{N} \sum_{i=1}^NL_i(f(x_i,W), y_i) L=N1​i=1∑N​Li​(f(xi​,W),yi​)

  • 即每个样本损失函数求和取平均。目标就是找到一个合适的 W W W 使 L L L 最小。
  • :真正的损失函数 L L L 还有一项正则损失 R ( W ) R(W) R(W),下面会有说明。

损失函数有很多种,下面介绍最常见的一些。

1.2 多类支持向量机损失 (Multiclass Support Vector Machine Loss)

SVM 的知识可以参考ShowMeAI的中的文章,多类 SVM 可以看作二分类 SVM 的一个推广,它可以把样本数据分为多个类别。

1) 数据损失(data loss)

SVM 的损失函数想要 SVM 在正确分类上的得分始终比不正确分类上的得分高出一个边界值 Δ \Delta Δ。

我们先看一条数据样本(一张图片)上的损失函数 L i L_i Li​ 如何定义,根据之前的描述,第 i i i 个数据 ( x i , y i ) (x_{i},y_{i}) (xi​,yi​) )中包含图像 x i x_i xi​ 的像素和代表正确类别的标签 y i y_i yi​。给评分函数输入像素数据,然后通过公式 f ( x i , W ) f(x_i, W) f(xi​,W) )来计算不同分类类别的分值。

这里我们将所有分值存放到 s s s 中,第 j j j 个类别的得分就是 s s s 的第 j j j 个元素: s j = f ( x i , W j ) s_j = f(x_i, W_j) sj​=f(xi​,Wj​)。针对第 i i i 条数据样本的多类 SVM 的损失函数定义如下:

L i = ∑ j ≠ y i max ⁡ ( 0 , s j − s y i + Δ ) L_i = \sum_{j\neq y_i} \max(0, s_j - s_{y_i} + \Delta) Li​=j​=yi​∑​max(0,sj​−syi​​+Δ)

直观来看,就是如果评分函数给真实标签的分数比其他某个标签的分数高出 Δ \Delta Δ,则对该其他标签的损失为 0 0 0;否则损失就是 s j − s y i + Δ s_j - s_{y_i}+ \Delta sj​−syi​​+Δ。要对所有不正确的分类循环一遍。

下面用一个示例来解释一下:

简化计算起见,我们只使用3个训练样本,对应3个类别的分类, y i = 0 , 1 , 2 y_i =0,1,2 yi​=0,1,2 对于第1张图片 「小猫」 来说,评分 s = [ 3.2 , 5.1 , − 1.7 ] s=[3.2, 5.1, -1.7] s=[3.2,5.1,−1.7] 其中 s y i = 3.2 s_{y_i}=3.2 syi​​=3.2 如果把 Δ \Delta Δ 设为 1 1 1,则针对小猫的损失函数:

L 1 = m a x ( 0 , 5.1 − 3.2 + 1 ) + m a x ( 0 , − 1.7 − 3.2 + 1 ) = m a x ( 0 , 2.9 ) + m a x ( 0 , − 3.9 ) = 2.9 + 0 = 2.9 L_1 = max(0, 5.1 - 3.2 + 1) +max(0, -1.7 - 3.2 + 1) = max(0, 2.9) + max(0, -3.9) = 2.9 + 0 =2.9 L1​=max(0,5.1−3.2+1)+max(0,−1.7−3.2+1)=max(0,2.9)+max(0,−3.9)=2.9+0=2.9

同理可得 L 2 = 0 L_2 =0 L2​=0, L 3 = 12.9 L_3 =12.9 L3​=12.9,所以对整个训练集的损失: L = ( 2.9 + 0 + 12.9 ) / 3 = 5.27 L= (2.9 + 0 + 12.9)/3 =5.27 L=(2.9+0+12.9)/3=5.27。

上面可以看到 SVM 的损失函数不仅想要正确分类类别 y i y_i yi​ 的分数比不正确类别分数高,而且至少要高 Δ \Delta Δ。如果不满足这点,就开始计算损失值。

:之所以会加入一个 Δ \Delta Δ,是为了真实标签的分数比错误标签的分数高出一定的距离,如上图所示,如果其他分类分数进入了红色的区域,甚至更高,那么就开始计算损失;如果没有这些情况,损失值为 0 0 0:

  • 损失最小是 0 0 0,最大无穷;
  • 如果求和的时候,不加 j ≠ y i j\neq y_i j​=yi​ 这一条件, L L L 会加 Δ \Delta Δ;
  • 计算 L i L_i Li​ 时使用平均不用求和,只会缩放 L L L 不会影响好坏;而如果使用平方,就会打破平衡,会使坏的更坏, L L L 受到影响。

在训练最开始的时候,往往会给 W W W 一个比较小的初值,结果就是 s s s 中所有值都很小接近于 0 0 0,此时的损失 L L L 应该等于分类类别数 K − 1 K-1 K−1,这里是 2 2 2。可根据这个判断代码是否有问题;

非向量化和向量化多类 SVM 损失代码实现如下:

def L_i(x, y, W):
  """
  非向量化版本。
  计算单个例子(x,y)的多类 SVM 损失    
  - x 是表示图像的列向量(例如,CIFAR-10中的3073 x 1),附加偏置维度
  - y 是一个给出正确类索引的整数(例如,CIFAR-10中的0到9之间)    
  - W 是权重矩阵(例如,CIFAR-10中的10 x 3073)  """
  delta = 1.0 # 间隔 delta
  scores = W.dot(x) # 得分数组,10 x 1
  correct_class_score = scores[y]
  D = W.shape[0] # 分类的总数,即为10
  loss_i = 0.0
  for j in range(D): # 迭代所有错误分类   
    if j == y:
      # 跳过正确分类的
      continue
    # 第 i 个样本累加损失
    loss_i += max(0, scores[j] - correct_class_score + delta)
  return loss_i

def L_i_vectorized(x, y, W):
  '''
  更快的半向量化实现。
  half-vectorized指的是这样一个事实:对于单个样本,实现不包含for循环,
  但是在样本外仍然有一个循环(在此函数之外)
  '''
  delta = 1.0
  scores = W.dot(x)
  # 用一个向量操作计算和所有类别的间隔
  margins = np.maximum(0, scores - scores[y] + delta)
  # y处的值应该为0  
  margins[y] = 0
  loss_i = np.sum(margins)
  return loss_i

这里的评分函数 f ( x i ; W ) = W x i f(x_i; W) = W x_i f(xi​;W)=Wxi​,所以损失函数可以写为:

L i = ∑ j ≠ y i max ⁡ ( 0 , w j T x i − w y i T x i + Δ ) L_i = \sum_{j\neq y_i} \max(0, w_j^T x_i - w_{y_i}^T x_i + \Delta) Li​=j​=yi​∑​max(0,wjT​xi​−w

标签: fz系列无源交流电流隔离变送器

锐单商城拥有海量元器件数据手册IC替代型号,打造 电子元器件IC百科大全!

锐单商城 - 一站式电子元器件采购平台