资讯详情

Transformer-XL: Attentive Language Models Beyond a Fixed-Length Context

Transformer最大的问题是没有办法建模超过最大长度的序列,Transformer-XL提出了两个优化点:段级递归和相对位置编码。

段级递归

为解决固定长度的限制,Transformer-XL第一个提出了递归机制,如下图所示segment计算完成后,保存计算结果,计算第二段时,保存第一段hidden state还有第二段hidden state拼接在一起,然后进行后续计算。

在这里插入图片描述 我们看下具体的计算公式,其中h表示的是hidden state, τ \tau τ 表示第 τ \tau τ 个segment,SG函数表示不更新梯度,[]表示向量拼接。

第一个公式的意思是:第一个 τ 1 \tau 1 τ 1个segment第n-1层的hidden state 等于第 τ \tau τ 个segment第n - 1层的hidden state拼接上第 τ 1 \tau 1 τ 1 个segment第n - 1层的hidden state,后续两个公式和vanilla版本相似,但要注意,q是未拼接的hidden state,k、v拼接后,因为q表示当前segment,所以不需要拼接。

可以看出,对于第一个segment来说,hidden state从第二个开始,没有额外的拼接值segment一开始就需要拼接。在论文中,每次都是和最后一segment理论上,拼接可以每次拼接多个segment,第n个segment可以和前n-1个segment进行拼接,但这取决于你自己的显存和一个segment一般来说,没有上图那么短(一个segment也许长度是512),文本本身的上下文一般不超过一个segment的长度。

现代

    def init_mems(self, bsz):         if self.mem_len > 0:             mems = []             for i in range(self.n_layer):                 empty = tf.zeros([self.mem_len, bsz, self.d_model])                 mems.append(empty)              return mems         else:             return None      def _update_mems(self, hids, mems, mlen, qlen):         # does not deal with None         if mems is None:             return None          # mems is not None         assert len(hids) == len(mems), "len(hids) != len(mems)"          # There are `mlen   qlen` steps that can be cached into mems         new_mems = []         end_idx = mlen   tf.math.maximum(0, qlen)         beg_idx = tf.math.maximum(0, end_idx - tf.convert_to_tensor(self.mem_len))         for i in range(len(hids)):             mems[i] = tf.cast(mems[i], dtype=hids[i].dtype)             cat = tf.concat([mems[i], hids[i]], axis=0)             tf.stop_gradient(cat)             new_mems.append(cat[beg_idx:end_idx])          return new_mems 

编码相对位置

Vanilla位置编码是和embedding添加后输入下一层,Transformer-XL输入中没有处理位置代码,而是对attention score修改。

考虑一下,当query与key实际上,计算时不需要知道key在绝对位置编码中,模型实际上需要一个时间线索,即单词的顺序。因此,我知道query与key相对位置。根据上述思路,Transformer-XL改进的三个方面如下:

在新的参数下,每一项都有具体的含义,a表示的是query与key内容相关性,b表示的是query的内容和key位置的相关性,c表示的是query的位置与key内容的相关性,d表示的是quey与key位置的相关性。

综上所述,一个N层,一个N层head的Transformer-XL,其完整步骤如下:

实现代码

class RelativeMultiHeadAttention(layers.Layer):     def __init__(self, num_heads, embed_size):         super(RelativeMultiHeadAttention, self).__init__()         self.num_heads = nu_heads
        self.embed_size = embed_size
        self.hidden_size = embed_size // num_heads

        self.qvk_net = layers.Dense(3 * embed_size)
        self.r_net = layers.Dense(embed_size)
        self.o_net = layers.Dense(embed_size)

        self.layer_norm = layers.LayerNormalization()

    def _rel_shift(self, x):
        x_size = tf.shape(x)
        # shape:(seq_len_q, seq_len_k, batch_size, num_heads)=>(seq_len_q, seq_len_k + 1, batch_size, num_heads)
        x = tf.pad(x, [[0, 0], [1, 0], [0, 0], [0, 0]])
        # shape:(seq_len_q, seq_len_k + 1, batch_size, num_heads)=>(seq_len_q + 1, seq_len_k, batch_size, num_heads)
        x = tf.reshape(x, (x_size[0] + 1, x_size[1], x_size[2], x_size[3]))
        # shape:(seq_len_q + 1, seq_len_k, batch_size, num_heads)=>(seq_len_q, seq_len_k, batch_size, num_heads)
        x = tf.slice(x, [0, 1, 0, 0], [-1, -1, -1, -1])
        return x

    # w表示token embedding,r表示relative position embedding
    # r_w_bias表示uT,r_r_bias表示vT,形状和w的形状一致
    def __call__(self, w, r, r_w_bias, r_r_bias, mask=None, mems=None, *args, **kwargs):
        # w
        # shape:(seq_len, batch_size, embed_size)
        # r
        # shape:(seq_len, 1, embed_size)
        seq_len_q, batch_size, seq_len_r = tf.shape(w)[0], tf.shape(w)[1], tf.shape(r)[0]
        if mems is not None:
            cat = tf.concat([mems, w], axis=0)
            w_heads = self.qvk_net(cat)
            # 有mems时:
            # w_head_q
            # shape:(seq_len_q, batch_size, embed_size)
            # w_head_k, w_head_v
            # shape:(seq_len_k, batch_size, embed_size),其中seq_len_k = seq_len_q + seq_len_mems
            w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
            w_head_q = w_head_q[-seq_len_q:]
            r_head_k = self.r_net(r)
        else:
            w_heads = self.qvk_net(w)
            # 没有mems时:(seq_len_q = seq_len)
            # w_head_q, w_head_k, w_head_v
            # shape:(seq_len_q, batch_size, embed_size)
            w_head_q, w_head_k, w_head_v = tf.split(w_heads, 3, axis=-1)
            r_head_k = self.r_net(r)
        seq_len_k = tf.shape(w_head_k)[0]
        # w_head_q
        # shape:(seq_len_q, batch_size, embed_size)=>(seq_len_q, batch_size, num_heads, hidden_size)
        # w_head_k, w_head_v
        # shape:(seq_len_k, batch_size, embed_size)=>(seq_len_k, batch_size, num_heads, hidden_size)
        # r_head_k
        # shape:(seq_len_r, 1, embed_size)=>(seq_len_r, num_heads, hidden_size)
        w_head_q = tf.reshape(w_head_q, (seq_len_q, batch_size, self.num_heads, self.hidden_size))
        w_head_k = tf.reshape(w_head_k, (seq_len_k, batch_size, self.num_heads, self.hidden_size))
        w_head_v = tf.reshape(w_head_v, (seq_len_k, batch_size, self.num_heads, self.hidden_size))
        r_head_k = tf.reshape(r_head_k, (seq_len_r, self.num_heads, self.hidden_size))
        # 计算A+C两项,(w_head_q + r_w_bias) * w_head_k = (qT + uT) * k
        # w_head_q
        # shape:(seq_len_q, batch_size, num_heads, hidden_size)
        # r_w_bias
        # shape:(seq_len_q, batch_size, num_heads, hidden_size)
        # w_head_k
        # shape:(seq_len_k, batch_size, num_heads, hidden_size)
        wr_head_q = w_head_q + r_w_bias
        # shape:(seq_len_q, seq_len_k, batch_size, num_heads)
        AC = tf.einsum("ibnh,jbnh->ijbn", wr_head_q, w_head_k)
        # 计算B+D两项,(w_head_q + r_r_bias) * r_head_k = (qT + vT) * r
        wr_head_r = w_head_q + r_r_bias
        # shape:(seq_len_q, seq_len_k, batch_size, num_heads)
        BD = tf.einsum("ibnh,jnh->ijbn", wr_head_r, r_head_k)
        BD = self.rel_shift(BD)
        # 计算attention_score,attention_score = softmax((A+B+C+D)/dk[+mask])
        attention_score = (AC + BD) / tf.sqrt(self.hidden_size)
        # 如果有mask
        if mask is not None:
            attention_score += (mask * 1e-9)
        # shape:(seq_len_q, seq_len_k, batch_size, num_heads)
        attention_score = tf.nn.softmax(attention_score, axis=1)
        # 计算attention,attention = attention_score * v
        # shape:(seq_len_q, batch_size, num_heads, hidden_size)
        attention = tf.einsum("ijbn,jbnh->ibnh", attention_score, w_head_v)
        # shape:(seq_len_q, batch_size, num_heads, hidden_size)=>(seq_len_q, batch_size, embed_size)
        attention = tf.reshape(attention, (seq_len_q, batch_size, self.embed_size))
        attention = self.o_net(attention)
        # residual connection
        output = attention + w
        # layer normalization
        output = self.layer_norm(output)
        return output

标签: bsz808a振动传感器变送器

锐单商城拥有海量元器件数据手册IC替代型号,打造 电子元器件IC百科大全!

 锐单商城 - 一站式电子元器件采购平台  

 深圳锐单电子有限公司