资讯详情

自然语言处理入门——RNN架构解析

介绍自然语言处理

RNN架构解析

认识RNN模型

  • RNN:中文称循环神经网络,一般以序列数据为输入,通过网络内部结构设计有效捕捉序列之间的关系特征,一般以序列形式输出。

  • RNN单层网络结构:

请添加图片描述

  • 以时间步对RNN单层网络结构展开:(这看起来像和CNN比较像了)

  • RNN循环机制可以作为当前时间步输入的一部分。

  • 因为RNN结构可以很好地利用序列之间的关系,因此可以很好地处理人类语言、语音等自然连续输入,广泛应用于NLP文本分类、情感分析、意图识别、机器翻译等各个领域的任务。

  • 假设用户输入What time is it? 下面是RNN处理方法:

  • 最终输出O对用户意图进行处理分析。

  • seq2seq由于其输入输出不限,架构最早被提出用于机器翻译,现在也是应用最广泛的RNN模型结构在机器翻译、阅读理解、文本摘要等多个领域都有很多应用实践。

传统RNN模型

  • 以中间的方块为例,它的输入有两部分,分别是h(t-1)和x(t),代表上一时间步的隐层输出和此时间步的输入。RNN结构结构后,会融合在一起,实际上是拼接形成新的张量[x(t), h(t-1)],通过全连接层(线性层)使用此张量tanh双曲正切作为激活函数,最终得到时间步的输出h(t),它将作为下一步的输入和x(t 1)一起进入结构体。

  • h t = t a n h ( W t [ X t , h t ? 1 ] b t ) h_t = tanh(W_t[X_t, h_{t-1}] b_t) ht=tanh(Wt[Xt​,ht−1​]+bt​)

  • 激活函数tanh:用于帮助调节流经神经网络的值,tanh函数将值压缩在-1和1之间。

pytorch中RNN的应用:

  • RNN类在torch.nn.RNN中,其初始化主要参数解释:

    • input_size: The number of expected features in the input `x`. 输入张量x中特征维度大小
    • hidden_size: The number of features in the hidden state `h`. 隐层张量h中特征维度大小
    • num_layers: Number of recurrent layers. 隐含层数量
    • nonlinearity: The non-linearity to use. Can be either 'tanh' or 'relu'. Default: 'tanh'
  • RNN类实例化对象主要参数解释:

    • : 输入张量x
    • : 初始化的隐层张量h
import torch
import torch.nn as nn

rnn = nn.RNN(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)
print(output)
print(output.shape)

print(hn)
print(hn.shape)

output:
tensor([[[-0.1823,  0.0859, -0.3405, -0.5000, -0.6306,  0.5065,  0.8202,
          -0.2139,  0.5886, -0.1549,  0.3035,  0.3669, -0.3702, -0.0026,
           0.0604,  0.1055, -0.0163, -0.2904,  0.2216,  0.0020],
         [ 0.2791, -0.1390,  0.3652,  0.0539, -0.5179,  0.7433, -0.0418,
           0.8043,  0.5498, -0.0131,  0.4987,  0.8964,  0.0033,  0.1708,
          -0.0594, -0.0106, -0.5742,  0.5557, -0.3524, -0.3199],
         [-0.0448,  0.2398, -0.1254,  0.5049,  0.6504,  0.6963,  0.5681,
           0.5640,  0.2442, -0.6644,  0.2833,  0.7397,  0.0966,  0.4050,
           0.5397,  0.1153, -0.5372, -0.4970,  0.0586, -0.1714]],

        [[ 0.5013,  0.1563,  0.1514,  0.1719, -0.3103, -0.4294, -0.6875,
           0.0665, -0.4604,  0.1708,  0.1925,  0.0077, -0.2452, -0.1904,
           0.4462,  0.0012, -0.2967, -0.5996, -0.0416,  0.0766],
         [ 0.6177,  0.4556,  0.3853,  0.2834,  0.3121, -0.1427, -0.4408,
          -0.1028, -0.7400,  0.2298, -0.5990,  0.4145,  0.1973,  0.1061,
           0.3418, -0.1150, -0.2209, -0.5048,  0.0269,  0.4954],
         [ 0.1053,  0.3156,  0.2890,  0.2079,  0.0477,  0.2353, -0.0389,
          -0.0014, -0.2171, -0.0972,  0.0658,  0.4972,  0.2478,  0.0355,
           0.4458,  0.0405, -0.5211, -0.2562, -0.1064,  0.5259]],

        [[-0.4234,  0.1803,  0.1560,  0.4580, -0.2345, -0.2388,  0.3107,
          -0.0058,  0.0634,  0.0977, -0.4543,  0.0582, -0.0860,  0.2199,
           0.1864, -0.5531,  0.5284, -0.2800, -0.0510,  0.0912],
         [-0.5156,  0.4014,  0.0628,  0.3032, -0.0117, -0.1661,  0.5899,
           0.1559,  0.2996, -0.4454,  0.0348,  0.0651, -0.5742,  0.2271,
           0.1080, -0.3659,  0.6118, -0.4189,  0.0549, -0.0393],
         [-0.0868,  0.5991,  0.1813,  0.5599,  0.3917, -0.3454,  0.0961,
           0.0566,  0.0284, -0.3377,  0.0170, -0.1184, -0.5352,  0.3805,
           0.1599, -0.1647,  0.2100, -0.5550,  0.1266,  0.2302]],

        [[-0.3534,  0.1374,  0.1209,  0.0387, -0.1049, -0.2417,  0.1742,
          -0.2224,  0.5119, -0.6369,  0.3746,  0.4883, -0.1907,  0.3288,
           0.1200,  0.0569,  0.0759, -0.1567,  0.3188,  0.2419],
         [ 0.0486,  0.3413,  0.1351,  0.5912, -0.3284, -0.3300, -0.0787,
           0.1665,  0.1738, -0.2786,  0.3029,  0.0880,  0.3581, -0.0811,
           0.4021, -0.3304, -0.2823, -0.2832,  0.1019,  0.3242],
         [-0.1608,  0.1246,  0.0863,  0.3260, -0.2099,  0.2095,  0.4521,
           0.4346,  0.3898, -0.4924,  0.2472,  0.2306,  0.3713, -0.0955,
           0.3075, -0.2875, -0.1641, -0.2343, -0.1563,  0.3321]],

        [[ 0.0825,  0.1601,  0.3947,  0.4077,  0.2677, -0.4913, -0.3839,
          -0.1458, -0.0360,  0.2327,  0.0738,  0.2608,  0.3539, -0.0649,
           0.2555, -0.2991, -0.2076, -0.2232, -0.2227,  0.1667],
         [ 0.1509,  0.4328,  0.1422,  0.5717, -0.1864, -0.0670,  0.1437,
           0.2307, -0.2267,  0.0637, -0.2936,  0.2159,  0.4971,  0.0890,
           0.2068, -0.5263,  0.3438, -0.2558,  0.1983,  0.4059],
         [ 0.2452,  0.6272,  0.1001,  0.2857,  0.2405, -0.3052, -0.0981,
           0.1137, -0.1352, -0.3110,  0.1155,  0.0432, -0.0898, -0.0357,
           0.1601, -0.0056, -0.1636, -0.4776,  0.1329,  0.3925]]],
       grad_fn=<StackBackward>)
torch.Size([5, 3, 20])

hn:
tensor([[[ 0.4038,  0.0203,  0.0424,  0.3863, -0.0233, -0.4767, -0.2328,
           0.6005, -0.1970, -0.1546,  0.1492, -0.4493, -0.8081,  0.7578,
          -0.6790,  0.1135, -0.1818, -0.0088, -0.2488, -0.1144],
         [-0.1782, -0.4026,  0.4985, -0.5878,  0.4833,  0.5260,  0.0710,
           0.5741,  0.3153, -0.0460,  0.1516,  0.3593, -0.7491,  0.4448,
          -0.6297,  0.2588,  0.4649, -0.2219, -0.5977,  0.3895],
         [-0.1429, -0.0431,  0.7642, -0.3143,  0.4679,  0.1455,  0.3204,
          -0.2070, -0.1016, -0.4045, -0.3219, -0.2693, -0.6370, -0.0010,
          -0.5872, -0.5141,  0.0144, -0.4947,  0.5004, -0.5219]],

        [[ 0.0825,  0.1601,  0.3947,  0.4077,  0.2677, -0.4913, -0.3839,
          -0.1458, -0.0360,  0.2327,  0.0738,  0.2608,  0.3539, -0.0649,
           0.2555, -0.2991, -0.2076, -0.2232, -0.2227,  0.1667],
         [ 0.1509,  0.4328,  0.1422,  0.5717, -0.1864, -0.0670,  0.1437,
           0.2307, -0.2267,  0.0637, -0.2936,  0.2159,  0.4971,  0.0890,
           0.2068, -0.5263,  0.3438, -0.2558,  0.1983,  0.4059],
         [ 0.2452,  0.6272,  0.1001,  0.2857,  0.2405, -0.3052, -0.0981,
           0.1137, -0.1352, -0.3110,  0.1155,  0.0432, -0.0898, -0.0357,
           0.1601, -0.0056, -0.1636, -0.4776,  0.1329,  0.3925]]],
       grad_fn=<StackBackward>)
torch.Size([2, 3, 20])

传统RNN模型的优缺点

  • 优势:

    • 由于内部结构简单,对计算资源要求低,相比之后学的RNN的变体参数总量少了很多,在任务上性能、效果都表现优异。
  • 劣势:

    • 传统RNN在解决长序列之间的关联时,通过实践证明其表现很差,原因是在进行反向传播时,过长的序列导致梯度计算异常,发生梯度消失或爆炸。
  • 什么是梯度消失或爆炸:

    • 根据反向传播算法及链式法则,梯度的计算可以简化为以下公式

    • D n = σ ′ ( z 1 ) w 1 ⋅ σ ′ ( z 2 ) w 2 ⋅ . . . ⋅ σ ′ ( z n ) w n D_n = σ'(z_1)w_1 · σ'(z_2)w_2 · ... ·σ'(z_n)w_n Dn​=σ′(z1​)w1​⋅σ′(z2​)w2​⋅...⋅σ′(zn​)wn​

    • 其中sigmoid的导数值是固定的,在[0, 0.25]之间,而一旦公式中w也小于1,那么通过这样的公式连乘之后,最终的梯度会变得非常非常小,这种现象称作梯度消失,反正,如果认为增大w的值,使其大于1,最终可能造成梯度过大,称作梯度爆炸。

    • 如果梯度消失,权重将无法被更新,最终导致训练失败。梯度爆炸带来的梯度过大,大幅更新网络参数,在极端情况下结果会溢出(NaN)

LSTM模型

  • LSTM也称长短时记忆结构,是传统RNN的变体,与经典RNN相比能够有效捕捉长序列之间的语义关联,缓解梯度消失或爆炸现象。同时LSTM的结构更复杂,它的结构可以分为四个部分解析:
    • 遗忘门
    • 输入门
    • 细胞状态
    • 输出

标签: 4595连接器3917连接器

锐单商城拥有海量元器件数据手册IC替代型号,打造 电子元器件IC百科大全!

锐单商城 - 一站式电子元器件采购平台