介绍自然语言处理
RNN架构解析
认识RNN模型
-
RNN:中文称循环神经网络,一般以序列数据为输入,通过网络内部结构设计有效捕捉序列之间的关系特征,一般以序列形式输出。
-
RNN单层网络结构:
- 以时间步对RNN单层网络结构展开:(这看起来像和CNN比较像了)
-
RNN循环机制可以作为当前时间步输入的一部分。
-
因为RNN结构可以很好地利用序列之间的关系,因此可以很好地处理人类语言、语音等自然连续输入,广泛应用于NLP文本分类、情感分析、意图识别、机器翻译等各个领域的任务。
-
假设用户输入What time is it? 下面是RNN处理方法:
- 最终输出O对用户意图进行处理分析。
- RNN模型分类:
- 分类输入输出结构:
- N vs N - RNN
- N vs 1 - RNN
- 1 vs N - RNN
- N vs M - RNN
- 从RNN内部结构分类:
- 传统RNN
- LSTM
- Bi-LSTM
- GRU
- Bi-GRU
- N vs N - RNN:输入和输出等长,一般用于生成等长的诗歌。
- N vs 1 - RNN:要求输出是一个单独的值,只要在最后一个隐层的输出h上进行线性变换就可以了。大部分情况下为了明确结果,还要使用sigmoid或softmax这种结构经常用于文本分类。
- 1 vs N - RNN:每次输出时,唯一的输入可以用来生成图片的文本任务。
- N vs M - RNN:这是一种无限输入输出长度RNN结构,由编码器和解码器两部分组成,两部分内部结构都是某类RNN,也被称为seq2seq首先通过编码器输入数据,最后输出隐含变量c,然后在解码器解码的每个时间步骤上使用输入信息的有效利用。
- seq2seq由于其输入输出不限,架构最早被提出用于机器翻译,现在也是应用最广泛的RNN模型结构在机器翻译、阅读理解、文本摘要等多个领域都有很多应用实践。
传统RNN模型
- 以中间的方块为例,它的输入有两部分,分别是h(t-1)和x(t),代表上一时间步的隐层输出和此时间步的输入。RNN结构结构后,会融合在一起,实际上是拼接形成新的张量
[x(t), h(t-1)]
,通过全连接层(线性层)使用此张量tanh双曲正切作为激活函数,最终得到时间步的输出h(t),它将作为下一步的输入和x(t 1)一起进入结构体。
-
h t = t a n h ( W t [ X t , h t ? 1 ] b t ) h_t = tanh(W_t[X_t, h_{t-1}] b_t) ht=tanh(Wt[Xt,ht−1]+bt)
-
激活函数tanh:用于帮助调节流经神经网络的值,tanh函数将值压缩在-1和1之间。
pytorch中RNN的应用:
-
RNN类在torch.nn.RNN中,其初始化主要参数解释:
- input_size: The number of expected features in the input `x`. 输入张量x中特征维度大小
- hidden_size: The number of features in the hidden state `h`. 隐层张量h中特征维度大小
- num_layers: Number of recurrent layers. 隐含层数量
- nonlinearity: The non-linearity to use. Can be either
'tanh'
or'relu'
. Default:'tanh'
-
RNN类实例化对象主要参数解释:
- : 输入张量x
- : 初始化的隐层张量h
import torch
import torch.nn as nn
rnn = nn.RNN(10, 20, 2)
input = torch.randn(5, 3, 10)
h0 = torch.randn(2, 3, 20)
output, hn = rnn(input, h0)
print(output)
print(output.shape)
print(hn)
print(hn.shape)
output:
tensor([[[-0.1823, 0.0859, -0.3405, -0.5000, -0.6306, 0.5065, 0.8202,
-0.2139, 0.5886, -0.1549, 0.3035, 0.3669, -0.3702, -0.0026,
0.0604, 0.1055, -0.0163, -0.2904, 0.2216, 0.0020],
[ 0.2791, -0.1390, 0.3652, 0.0539, -0.5179, 0.7433, -0.0418,
0.8043, 0.5498, -0.0131, 0.4987, 0.8964, 0.0033, 0.1708,
-0.0594, -0.0106, -0.5742, 0.5557, -0.3524, -0.3199],
[-0.0448, 0.2398, -0.1254, 0.5049, 0.6504, 0.6963, 0.5681,
0.5640, 0.2442, -0.6644, 0.2833, 0.7397, 0.0966, 0.4050,
0.5397, 0.1153, -0.5372, -0.4970, 0.0586, -0.1714]],
[[ 0.5013, 0.1563, 0.1514, 0.1719, -0.3103, -0.4294, -0.6875,
0.0665, -0.4604, 0.1708, 0.1925, 0.0077, -0.2452, -0.1904,
0.4462, 0.0012, -0.2967, -0.5996, -0.0416, 0.0766],
[ 0.6177, 0.4556, 0.3853, 0.2834, 0.3121, -0.1427, -0.4408,
-0.1028, -0.7400, 0.2298, -0.5990, 0.4145, 0.1973, 0.1061,
0.3418, -0.1150, -0.2209, -0.5048, 0.0269, 0.4954],
[ 0.1053, 0.3156, 0.2890, 0.2079, 0.0477, 0.2353, -0.0389,
-0.0014, -0.2171, -0.0972, 0.0658, 0.4972, 0.2478, 0.0355,
0.4458, 0.0405, -0.5211, -0.2562, -0.1064, 0.5259]],
[[-0.4234, 0.1803, 0.1560, 0.4580, -0.2345, -0.2388, 0.3107,
-0.0058, 0.0634, 0.0977, -0.4543, 0.0582, -0.0860, 0.2199,
0.1864, -0.5531, 0.5284, -0.2800, -0.0510, 0.0912],
[-0.5156, 0.4014, 0.0628, 0.3032, -0.0117, -0.1661, 0.5899,
0.1559, 0.2996, -0.4454, 0.0348, 0.0651, -0.5742, 0.2271,
0.1080, -0.3659, 0.6118, -0.4189, 0.0549, -0.0393],
[-0.0868, 0.5991, 0.1813, 0.5599, 0.3917, -0.3454, 0.0961,
0.0566, 0.0284, -0.3377, 0.0170, -0.1184, -0.5352, 0.3805,
0.1599, -0.1647, 0.2100, -0.5550, 0.1266, 0.2302]],
[[-0.3534, 0.1374, 0.1209, 0.0387, -0.1049, -0.2417, 0.1742,
-0.2224, 0.5119, -0.6369, 0.3746, 0.4883, -0.1907, 0.3288,
0.1200, 0.0569, 0.0759, -0.1567, 0.3188, 0.2419],
[ 0.0486, 0.3413, 0.1351, 0.5912, -0.3284, -0.3300, -0.0787,
0.1665, 0.1738, -0.2786, 0.3029, 0.0880, 0.3581, -0.0811,
0.4021, -0.3304, -0.2823, -0.2832, 0.1019, 0.3242],
[-0.1608, 0.1246, 0.0863, 0.3260, -0.2099, 0.2095, 0.4521,
0.4346, 0.3898, -0.4924, 0.2472, 0.2306, 0.3713, -0.0955,
0.3075, -0.2875, -0.1641, -0.2343, -0.1563, 0.3321]],
[[ 0.0825, 0.1601, 0.3947, 0.4077, 0.2677, -0.4913, -0.3839,
-0.1458, -0.0360, 0.2327, 0.0738, 0.2608, 0.3539, -0.0649,
0.2555, -0.2991, -0.2076, -0.2232, -0.2227, 0.1667],
[ 0.1509, 0.4328, 0.1422, 0.5717, -0.1864, -0.0670, 0.1437,
0.2307, -0.2267, 0.0637, -0.2936, 0.2159, 0.4971, 0.0890,
0.2068, -0.5263, 0.3438, -0.2558, 0.1983, 0.4059],
[ 0.2452, 0.6272, 0.1001, 0.2857, 0.2405, -0.3052, -0.0981,
0.1137, -0.1352, -0.3110, 0.1155, 0.0432, -0.0898, -0.0357,
0.1601, -0.0056, -0.1636, -0.4776, 0.1329, 0.3925]]],
grad_fn=<StackBackward>)
torch.Size([5, 3, 20])
hn:
tensor([[[ 0.4038, 0.0203, 0.0424, 0.3863, -0.0233, -0.4767, -0.2328,
0.6005, -0.1970, -0.1546, 0.1492, -0.4493, -0.8081, 0.7578,
-0.6790, 0.1135, -0.1818, -0.0088, -0.2488, -0.1144],
[-0.1782, -0.4026, 0.4985, -0.5878, 0.4833, 0.5260, 0.0710,
0.5741, 0.3153, -0.0460, 0.1516, 0.3593, -0.7491, 0.4448,
-0.6297, 0.2588, 0.4649, -0.2219, -0.5977, 0.3895],
[-0.1429, -0.0431, 0.7642, -0.3143, 0.4679, 0.1455, 0.3204,
-0.2070, -0.1016, -0.4045, -0.3219, -0.2693, -0.6370, -0.0010,
-0.5872, -0.5141, 0.0144, -0.4947, 0.5004, -0.5219]],
[[ 0.0825, 0.1601, 0.3947, 0.4077, 0.2677, -0.4913, -0.3839,
-0.1458, -0.0360, 0.2327, 0.0738, 0.2608, 0.3539, -0.0649,
0.2555, -0.2991, -0.2076, -0.2232, -0.2227, 0.1667],
[ 0.1509, 0.4328, 0.1422, 0.5717, -0.1864, -0.0670, 0.1437,
0.2307, -0.2267, 0.0637, -0.2936, 0.2159, 0.4971, 0.0890,
0.2068, -0.5263, 0.3438, -0.2558, 0.1983, 0.4059],
[ 0.2452, 0.6272, 0.1001, 0.2857, 0.2405, -0.3052, -0.0981,
0.1137, -0.1352, -0.3110, 0.1155, 0.0432, -0.0898, -0.0357,
0.1601, -0.0056, -0.1636, -0.4776, 0.1329, 0.3925]]],
grad_fn=<StackBackward>)
torch.Size([2, 3, 20])
传统RNN模型的优缺点
-
优势:
- 由于内部结构简单,对计算资源要求低,相比之后学的RNN的变体参数总量少了很多,在任务上性能、效果都表现优异。
-
劣势:
- 传统RNN在解决长序列之间的关联时,通过实践证明其表现很差,原因是在进行反向传播时,过长的序列导致梯度计算异常,发生梯度消失或爆炸。
-
什么是梯度消失或爆炸:
-
根据反向传播算法及链式法则,梯度的计算可以简化为以下公式
-
D n = σ ′ ( z 1 ) w 1 ⋅ σ ′ ( z 2 ) w 2 ⋅ . . . ⋅ σ ′ ( z n ) w n D_n = σ'(z_1)w_1 · σ'(z_2)w_2 · ... ·σ'(z_n)w_n Dn=σ′(z1)w1⋅σ′(z2)w2⋅...⋅σ′(zn)wn
-
其中sigmoid的导数值是固定的,在[0, 0.25]之间,而一旦公式中w也小于1,那么通过这样的公式连乘之后,最终的梯度会变得非常非常小,这种现象称作梯度消失,反正,如果认为增大w的值,使其大于1,最终可能造成梯度过大,称作梯度爆炸。
-
如果梯度消失,权重将无法被更新,最终导致训练失败。梯度爆炸带来的梯度过大,大幅更新网络参数,在极端情况下结果会溢出(NaN)
-
LSTM模型
- LSTM也称长短时记忆结构,是传统RNN的变体,与经典RNN相比能够有效捕捉长序列之间的语义关联,缓解梯度消失或爆炸现象。同时LSTM的结构更复杂,它的结构可以分为四个部分解析:
- 遗忘门
- 输入门
- 细胞状态
- 输出