The biggest limitation of the vanilla Transformer is that it cannot model sequences longer than a fixed maximum length. Transformer-XL addresses this with two improvements: segment-level recurrence and relative position encoding.
Segment-Level Recurrence
To overcome the fixed-length limitation, Transformer-XL introduces a recurrence mechanism. As shown in the figure below, after a segment has been processed, its hidden states are cached; when the next segment is computed, the cached hidden states of the previous segment are concatenated with the hidden states of the current segment before the subsequent computation.
Let's look at the concrete formulas, where $h$ denotes a hidden state, $\tau$ indexes the segment, $\mathrm{SG}(\cdot)$ stops the gradient, and $[\cdot \circ \cdot]$ denotes concatenation:

$$\tilde{h}_{\tau+1}^{n-1} = \left[\mathrm{SG}(h_{\tau}^{n-1}) \circ h_{\tau+1}^{n-1}\right]$$

$$q_{\tau+1}^{n},\ k_{\tau+1}^{n},\ v_{\tau+1}^{n} = h_{\tau+1}^{n-1}{W_q}^{\top},\ \tilde{h}_{\tau+1}^{n-1}{W_k}^{\top},\ \tilde{h}_{\tau+1}^{n-1}{W_v}^{\top}$$

$$h_{\tau+1}^{n} = \text{Transformer-Layer}\left(q_{\tau+1}^{n},\ k_{\tau+1}^{n},\ v_{\tau+1}^{n}\right)$$

The first formula says that $\tilde{h}_{\tau+1}^{n-1}$, the extended hidden state of segment $\tau+1$ at layer $n-1$, is the layer-$(n-1)$ hidden state of segment $\tau$ (with gradients stopped) concatenated with the layer-$(n-1)$ hidden state of segment $\tau+1$. The next two formulas are similar to the vanilla version, with one caveat: $q$ is computed from the un-concatenated hidden state, while $k$ and $v$ are computed from the concatenated one, because the query represents only the current segment.
Note that the first segment has nothing extra to concatenate; only from the second segment onward is concatenation needed. In the paper, each segment is concatenated only with the immediately preceding segment. In theory you could concatenate more than one: the $n$-th segment could be concatenated with all of the previous $n-1$ segments, but that is bounded by your GPU memory. In practice a segment is usually not as short as in the figure above (a segment might be 512 tokens long), and the context needed by the text itself rarely exceeds the length of one segment.
Implementation
```python
import tensorflow as tf


def init_mems(self, bsz):
    # one empty memory tensor per layer: (mem_len, batch_size, d_model)
    if self.mem_len > 0:
        mems = []
        for i in range(self.n_layer):
            empty = tf.zeros([self.mem_len, bsz, self.d_model])
            mems.append(empty)
        return mems
    else:
        return None


def _update_mems(self, hids, mems, mlen, qlen):
    # does not deal with None
    if mems is None:
        return None
    # mems is not None
    assert len(hids) == len(mems), "len(hids) != len(mems)"
    # There are `mlen + qlen` steps that can be cached into mems
    new_mems = []
    end_idx = mlen + tf.math.maximum(0, qlen)
    beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len))
    for i in range(len(hids)):
        mems[i] = tf.cast(mems[i], dtype=hids[i].dtype)
        # concatenate the cached states with the new hidden states, then keep
        # only the most recent `mem_len` steps, detached from the graph
        cat = tf.concat([mems[i], hids[i]], axis=0)
        cat = tf.stop_gradient(cat)
        new_mems.append(cat[beg_idx:end_idx])
    return new_mems
```
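To make the data flow concrete, here is a minimal sketch of how these two functions are driven across consecutive segments. The `SimpleNamespace` stand-in, the concrete sizes, and the random `hids` tensors are illustrative assumptions replacing a real model's forward pass; only `init_mems` and `_update_mems` come from the snippet above.

```python
import types

import tensorflow as tf

# stand-in object carrying the attributes the two methods above expect
model = types.SimpleNamespace(n_layer=2, mem_len=4, d_model=8)
model.init_mems = types.MethodType(init_mems, model)
model._update_mems = types.MethodType(_update_mems, model)

seg_len, bsz = 4, 2
mems = model.init_mems(bsz)  # n_layer tensors of shape (mem_len, bsz, d_model)

for step in range(3):  # three consecutive segments
    # `hids` stands in for the per-layer hidden states a real forward pass
    # would produce for the current segment
    hids = [tf.random.normal((seg_len, bsz, model.d_model))
            for _ in range(model.n_layer)]
    mlen = tf.shape(mems[0])[0]
    mems = model._update_mems(hids, mems, mlen=mlen, qlen=seg_len)

print(mems[0].shape)  # (mem_len, bsz, d_model): only the newest steps are kept
```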
Relative Position Encoding
In the vanilla Transformer, the positional encoding is added to the token embedding before being fed into the next layer. Transformer-XL does not inject position information at the input; instead, it modifies the attention score.
Consider this: when a query attends to a key, the computation does not actually need the key's absolute position. What the model really needs is a temporal clue, i.e., the order of the words, so it is enough to know the relative position between the query and the key. Following this idea, Transformer-XL changes three things in the attention score: (1) the key's absolute position embedding $U_j$ is replaced by the relative position encoding $R_{i-j}$; (2) the query's position term $U_i^{\top} W_q^{\top}$ is replaced by two trainable vectors $u$ and $v$ that do not depend on the query position; (3) the key projection $W_k$ is split into $W_{k,E}$ for content and $W_{k,R}$ for position.
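Substituting these changes into the expanded attention score gives the reparameterized form from the paper, which decomposes into four terms:

$$A_{i,j}^{\mathrm{rel}} = \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,E}\, E_{x_j}}_{(a)} + \underbrace{E_{x_i}^{\top} W_q^{\top} W_{k,R}\, R_{i-j}}_{(b)} + \underbrace{u^{\top} W_{k,E}\, E_{x_j}}_{(c)} + \underbrace{v^{\top} W_{k,R}\, R_{i-j}}_{(d)}$$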
Under this new parameterization, each term has a concrete meaning: (a) captures the relevance between query content and key content (content-based addressing); (b) captures the relevance between the query's content and the key's position (a content-dependent positional bias); (c) relates the query side, now the global vector $u$, to the key's content (a global content bias); (d) relates the query and key positions via the global vector $v$ (a global positional bias).
Putting everything together, the complete computational procedure for an N-layer Transformer-XL with a single attention head is as follows:
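Reconstructed from the paper, with $m_{\tau}^{n-1}$ denoting the cached memory from the previous segment and $h_{\tau}^{0} := E_{s_{\tau}}$ the word embedding sequence, each layer $n = 1, \dots, N$ computes:

$$\tilde{h}_{\tau}^{n-1} = \left[\mathrm{SG}(m_{\tau}^{n-1}) \circ h_{\tau}^{n-1}\right]$$

$$q_{\tau}^{n},\ k_{\tau}^{n},\ v_{\tau}^{n} = h_{\tau}^{n-1}{W_q^n}^{\top},\ \tilde{h}_{\tau}^{n-1}{W_{k,E}^n}^{\top},\ \tilde{h}_{\tau}^{n-1}{W_v^n}^{\top}$$

$$A_{\tau,i,j}^{n} = {q_{\tau,i}^{n}}^{\top} k_{\tau,j}^{n} + {q_{\tau,i}^{n}}^{\top} W_{k,R}^{n} R_{i-j} + u^{\top} k_{\tau,j}^{n} + v^{\top} W_{k,R}^{n} R_{i-j}$$

$$a_{\tau}^{n} = \text{Masked-Softmax}(A_{\tau}^{n})\, v_{\tau}^{n}$$

$$o_{\tau}^{n} = \text{LayerNorm}\left(\text{Linear}(a_{\tau}^{n}) + h_{\tau}^{n-1}\right)$$

$$h_{\tau}^{n} = \text{Positionwise-Feed-Forward}(o_{\tau}^{n})$$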
Implementation
```python
import tensorflow as tf
from tensorflow.keras import layers


class RelativeMultiHeadAttention(layers.Layer):
    def __init__(self, num_heads, embed_size):
        super(RelativeMultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.embed_size = embed_size
        self.hidden_size = embed_size // num_heads
        # single projection producing q, k, v at once
        self.qvk_net = layers.Dense(3 * embed_size)
        # projection for the relative position embeddings
        self.r_net = layers.Dense(embed_size)
        # output projection
        self.o_net = layers.Dense(embed_size)
        self.layer_norm = layers.LayerNormalization()
    def _rel_shift(self, x):
        # the standard Transformer-XL "relative shift": row i is shifted so
        # that column j ends up aligned with relative distance i - j
        x_size = tf.shape(x)
        # shape: (seq_len_q, seq_len_k, batch_size, num_heads)
        #     => (seq_len_q, seq_len_k + 1, batch_size, num_heads)
        x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
        # shape: (seq_len_q, seq_len_k + 1, batch_size, num_heads)
        #     => (seq_len_k + 1, seq_len_q, batch_size, num_heads)
        x = tf.reshape(x, (x_size[1] + 1, x_size[0], x_size[2], x_size[3]))
        # drop the first row => (seq_len_k, seq_len_q, batch_size, num_heads)
        x = tf.slice(x, [1, 0, 0, 0], [-1, -1, -1, -1])
        # shape => (seq_len_q, seq_len_k, batch_size, num_heads)
        x = tf.reshape(x, x_size)
        return x
    # w is the token embedding sequence, r is the relative position embedding;
    # r_w_bias corresponds to u and r_r_bias to v in the paper, both of shape
    # (num_heads, hidden_size), broadcast against the query heads
    def call(self, w, r, r_w_bias, r_r_bias, mask=None, mems=None):
        # w shape: (seq_len, batch_size, embed_size)
        # r shape: (seq_len_r, 1, embed_size)
        seq_len_q, batch_size, seq_len_r = tf.shape(w)[0], tf.shape(w)[1], tf.shape(r)[0]
        if mems is not None:
            # prepend the cached hidden states to the current segment
            cat = tf.concat([mems, w], axis=0)
            w_heads = self.qvk_net(cat)
            # with mems:
            # w_head_q shape: (seq_len_q, batch_size, embed_size)
            # w_head_k, w_head_v shape: (seq_len_k, batch_size, embed_size),
            # where seq_len_k = seq_len_q + seq_len_mems
            w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
            # the query covers only the current segment
            w_head_q = w_head_q[-seq_len_q:]
            r_head_k = self.r_net(r)
        else:
            w_heads = self.qvk_net(w)
            # without mems (seq_len_q = seq_len):
            # w_head_q, w_head_k, w_head_v shape: (seq_len_q, batch_size, embed_size)
            w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
            r_head_k = self.r_net(r)
        seq_len_k = tf.shape(w_head_k)[0]
        # w_head_q: (seq_len_q, batch_size, embed_size)
        #        => (seq_len_q, batch_size, num_heads, hidden_size)
        # w_head_k, w_head_v: (seq_len_k, batch_size, embed_size)
        #                  => (seq_len_k, batch_size, num_heads, hidden_size)
        # r_head_k: (seq_len_r, 1, embed_size) => (seq_len_r, num_heads, hidden_size)
        w_head_q = tf.reshape(w_head_q, (seq_len_q, batch_size, self.num_heads, self.hidden_size))
        w_head_k = tf.reshape(w_head_k, (seq_len_k, batch_size, self.num_heads, self.hidden_size))
        w_head_v = tf.reshape(w_head_v, (seq_len_k, batch_size, self.num_heads, self.hidden_size))
        r_head_k = tf.reshape(r_head_k, (seq_len_r, self.num_heads, self.hidden_size))
        # terms (a) + (c): (w_head_q + r_w_bias) @ w_head_k = (q + u)^T k
        # w_head_q shape: (seq_len_q, batch_size, num_heads, hidden_size)
        # r_w_bias shape: (num_heads, hidden_size), broadcast over the first two axes
        # w_head_k shape: (seq_len_k, batch_size, num_heads, hidden_size)
        wr_head_q = w_head_q + r_w_bias
        # shape: (seq_len_q, seq_len_k, batch_size, num_heads)
        AC = tf.einsum("ibnh,jbnh->ijbn", wr_head_q, w_head_k)
        # terms (b) + (d): (w_head_q + r_r_bias) @ r_head_k = (q + v)^T r
        wr_head_r = w_head_q + r_r_bias
        # shape: (seq_len_q, seq_len_r, batch_size, num_heads); seq_len_r must equal seq_len_k
        BD = tf.einsum("ibnh,jnh->ijbn", wr_head_r, r_head_k)
        BD = self._rel_shift(BD)
        # attention_score = softmax((A + B + C + D) / sqrt(d_k) [+ mask])
        attention_score = (AC + BD) / tf.sqrt(tf.cast(self.hidden_size, AC.dtype))
        # mask out forbidden positions with a large negative value before softmax
        if mask is not None:
            attention_score += mask * -1e9
        # shape: (seq_len_q, seq_len_k, batch_size, num_heads)
        attention_score = tf.nn.softmax(attention_score, axis=1)
        # attention = attention_score @ v
        # shape: (seq_len_q, batch_size, num_heads, hidden_size)
        attention = tf.einsum("ijbn,jbnh->ibnh", attention_score, w_head_v)
        # (seq_len_q, batch_size, num_heads, hidden_size) => (seq_len_q, batch_size, embed_size)
        attention = tf.reshape(attention, (seq_len_q, batch_size, self.embed_size))
        attention = self.o_net(attention)
        # residual connection
        output = attention + w
        # layer normalization
        output = self.layer_norm(output)
        return output
```
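Finally, a minimal end-to-end usage sketch. The sinusoidal construction of `r` (relative distances ordered from `seq_len_k - 1` down to `0`, concatenated sin/cos variant) follows the paper's formulation, while the concrete sizes and random inputs are illustrative assumptions; `r_w_bias` and `r_r_bias` are created outside the layer because the paper shares u and v across layers.

```python
import numpy as np
import tensorflow as tf


def rel_pos_embedding(klen, embed_size):
    # sinusoidal embedding over relative distances klen-1, ..., 0
    pos = np.arange(klen - 1, -1, -1.0)[:, None]                   # (klen, 1)
    inv_freq = 1.0 / (10000 ** (np.arange(0, embed_size, 2) / embed_size))
    angles = pos * inv_freq[None, :]                               # (klen, d/2)
    r = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)  # (klen, d)
    return tf.constant(r[:, None, :], dtype=tf.float32)            # (klen, 1, d)


num_heads, embed_size, seq_len, bsz = 4, 64, 8, 2
hidden_size = embed_size // num_heads

layer = RelativeMultiHeadAttention(num_heads, embed_size)
w = tf.random.normal((seq_len, bsz, embed_size))  # current segment
r = rel_pos_embedding(seq_len, embed_size)        # no mems: seq_len_k = seq_len
# u and v from the paper, one (num_heads, hidden_size) bias each
r_w_bias = tf.Variable(tf.random.normal((num_heads, hidden_size)))
r_r_bias = tf.Variable(tf.random.normal((num_heads, hidden_size)))

# causal mask: 1 marks positions the query may not attend to (key after query)
causal = 1.0 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
mask = causal[:, :, None, None]  # broadcast over batch and heads

out = layer(w, r, r_w_bias, r_r_bias, mask=mask)
print(out.shape)  # (8, 2, 64) == (seq_len, batch_size, embed_size)
```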