Model Overview
XLNet offers an interesting perspective: it divides existing pre-training models into two categories, AR (Auto-Regressive) models and AE (Auto-Encoder) models. XLNet combines the strengths of both approaches, and it does so through PLM (Permutation Language Model), a permutation-based language modeling objective.
Moreover, once a permutation objective is used, standard Transformer attention can no longer tell which token is being predicted. XLNet makes the target position explicit through a two-stream self-attention mechanism.
Model Optimization
Permutation Language Model
An autoregressive (AR) language model predicts each token from the tokens that precede it, i.e., the familiar left-to-right language model, or, running in reverse, predicts earlier tokens from the tokens that follow.

Autoregressive language models have clear pros and cons. The drawback is that they can only use context in one direction, never both directions at once. The advantage is that they match certain downstream NLP tasks naturally: generation tasks such as text summarization and machine translation produce content from left to right, which is exactly the autoregressive process.
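Concretely, an autoregressive LM factorizes the joint probability of a sequence into a product of left-to-right conditionals. A minimal sketch, where `p` is a stand-in for any conditional model (the uniform model below is purely illustrative, not part of XLNet):

```python
import numpy as np

def sequence_log_prob(tokens, p):
    # log p(x) = sum_t log p(x_t | x_<t), the left-to-right AR factorization
    return sum(np.log(p(tokens[:t], tokens[t])) for t in range(len(tokens)))

# with a (hypothetical) uniform model over a 4-token vocabulary:
uniform = lambda context, token: 0.25
print(sequence_log_prob([1, 2, 3], uniform))  # 3 * log(0.25)
```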
GPT is a typical autoregressive language model. ELMo appears to use both left and right context, but it is still essentially an autoregressive LM; this comes down to how the model is built. ELMo trains two autoregressive LMs in opposite directions (left-to-right and right-to-left) and concatenates the hidden states of the two LSTMs to approximate a bidirectional language model. It is therefore a concatenation of two autoregressive language models, and essentially still autoregressive.
BERT is a typical autoencoder (AE) language model. Its strengths and weaknesses exactly mirror those of autoregressive LMs: it can incorporate a bidirectional context naturally, seeing both sides of the word being predicted. BERT randomly masks some tokens in the input X and predicts the masked tokens from their surrounding context, so prediction can use information from both directions and learn the dependencies between words. Its main drawback is the [Mask] token introduced on the input side, which creates a mismatch between the pre-training and fine-tuning stages, because the fine-tuning stage never sees any [Mask] tokens.
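BERT's corruption step can be sketched as follows (the token ids are hypothetical, the 15% rate matches BERT's setup, and 103 is assumed here as the [Mask] id):

```python
import numpy as np

rng = np.random.default_rng(0)
tokens = np.array([7592, 2088, 2003, 2307, 2651])  # hypothetical WordPiece ids
MASK_ID = 103                                      # assumed [Mask] token id

is_masked = rng.random(len(tokens)) < 0.15         # BERT masks ~15% of positions
corrupted = np.where(is_masked, MASK_ID, tokens)
# training objective: recover tokens[is_masked] from the corrupted sequence,
# using context on both sides of each masked position
```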
XLNet implements PLM through an Attention Mask, without actually reordering the sentence's tokens. For example, if the original sentence is [1, 2, 3, 4] and the randomly sampled permutation is [3, 2, 4, 1], the input to XLNet is still [1, 2, 3, 4], but the attention mask must be modified as in the figure below.
In the mask matrix in the figure, red means visible (not masked) and white means masked. Row 1 corresponds to token 1, which sits last in the permutation 3, 2, 4, 1 and can therefore see 3, 2 and 4, so circles 2, 3 and 4 in the first row are red, meaning visible. Row 2 corresponds to token 2, which sits second in the permutation 3, 2, 4, 1 and can only see 3, so only the third circle in the second row is red. Row 3 corresponds to token 3, which sits first in the permutation 3, 2, 4, 1; nothing precedes it, so it can see no token at all, and all four circles are white, meaning masked.
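The mask just described can be built directly from the permutation. A small sketch (1 = visible/red, 0 = masked/white; indices are 1-based as in the figure):

```python
import numpy as np

def query_stream_mask(perm):
    # perm lists original (1-based) token indices in factorization order
    step = {tok: k for k, tok in enumerate(perm)}
    n = len(perm)
    mask = np.zeros((n, n), dtype=int)
    for i in perm:
        for j in perm:
            if step[j] < step[i]:   # i may attend to j only if j precedes i
                mask[i - 1, j - 1] = 1
    return mask

print(query_stream_mask([3, 2, 4, 1]))
# row 1 -> [0 1 1 1]: token 1 is last and sees 2, 3, 4
# row 3 -> [0 0 0 0]: token 3 is first and sees nothing
```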
Two-Stream Self-Attention
Because XLNet shuffles the factorization order, the position of the token being predicted becomes crucial, and at the same time that token's content must be hidden during prediction (otherwise the input would contain the answer and the model would learn nothing). In other words, XLNet must see the target token's position information but not its content information. XLNet therefore uses two streams:
Query Stream: for each token, its Query Stream encodes only that token's position information, and note that this is the token's position in the original sentence, not its position in the permutation. Content Stream: for each token, its Content Stream encodes the token's content information.
The Query Stream is denoted g and the Content Stream h. When the Query Stream is used to predict a target position, the Q (Query) vector is computed from g, so that only the target position's information enters the computation, while K (Key) and V (Value) are computed from h and carry the content information of the other tokens. The figure below shows how the next layer's g is computed from the current layer's g, for the permutation [3, 2, 4, 1] when the token being computed is 1.
Notice that the Q vector for token 1 is computed only from token 1's Query Stream g, so the model receives only token 1's position information, while the K and V vectors are computed from the h of tokens 3, 2 and 4, giving the model their content information, because token 1 is last in the permutation [3, 2, 4, 1]. The mask matrix for this step is the same as in the previous section, with a white diagonal, i.e., the content information h of the position currently being predicted is masked out.
Content Stream: the Content Stream carries the tokens' content. Because XLNet has many layers, each token's content must be propagated to the next layer, so here Q, K and V are all computed from h. The Content Stream computation is shown in the figure below.
As the figure shows, computing the next layer's h at position 1 also uses token 1's current content information, so the token's content is passed on to the next layer. Note, however, that XLNet makes predictions only from g (the Query Stream).
The only difference from the Query Stream's mask matrix is the diagonal: the Content Stream does not mask the diagonal, so each token's own information can flow through to the next layer.
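In code terms the difference really is just the diagonal. Taking the query-stream visibility matrix for the permutation [3, 2, 4, 1] described in the figure above (1 = visible), the content-stream mask additionally exposes each token to itself. A minimal sketch:

```python
import numpy as np

# query-stream visibility for the permutation [3, 2, 4, 1] (1 = visible)
query_vis = np.array([[0, 1, 1, 1],
                      [0, 0, 1, 0],
                      [0, 0, 0, 0],
                      [0, 1, 1, 0]])

# the content stream also sees the current token: un-mask the diagonal
content_vis = query_vis | np.eye(4, dtype=int)
print(content_vis)
```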
Combining the Query Stream and the Content Stream gives the full picture, shown in the figure below.
The bottom layer in the figure is the input layer, where e(x) is the word embedding of each token and w is a trainable vector. The corresponding code is given in the Model Code section below. Feel free to leave a comment if anything is unclear.
XLNet reorders the sentence and then predicts tokens autoregressively in the permuted order, but because the permutation is random, optimization is difficult and convergence is slow. XLNet therefore trains with Partial Prediction: only the tokens at the end of the permutation are predicted.
For example, with K = 4 only the last 1/4 of the tokens are predicted. Given the sentence [1,2,3,4,5,6,7,8] and a random permutation [2,8,3,4,5,1,7,6], only 7 and 6 are predicted. The paper uses K = 6 when training XLNet-Large, i.e., roughly the last 14.3% of the tokens are predicted.
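The cutoff can be computed directly from K. A small sketch of this selection (the 1/K fraction is taken from the end of the permuted order):

```python
def partial_prediction_targets(perm, K):
    # predict only the last len(perm) // K tokens of the permuted order
    num_predict = len(perm) // K
    return perm[len(perm) - num_predict:]

print(partial_prediction_targets([2, 8, 3, 4, 5, 1, 7, 6], K=4))
```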
XLNet also adopts the two most important techniques of Transformer-XL: relative positional encoding and the segment-level recurrence mechanism. See the Transformer-XL section for details.
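Segment-level recurrence amounts to caching the previous segment's hidden states as memory and prepending them, with gradients stopped, to the current segment. A minimal numpy sketch of the caching rule used by `cache_mem` in the code below (`mem_len` bounds the cache size):

```python
import numpy as np

def cache_mem(curr_out, prev_mem, mem_len):
    # append current hidden states to the memory, keep only the last mem_len steps
    if prev_mem is None:
        new_mem = curr_out[-mem_len:]
    else:
        new_mem = np.concatenate([prev_mem, curr_out], axis=0)[-mem_len:]
    return new_mem  # in the real model this is wrapped in tf.stop_gradient

seg1 = np.arange(4).reshape(4, 1).astype(float)    # (seq_len, d_model), d_model=1
seg2 = np.arange(4, 8).reshape(4, 1).astype(float)
mem = cache_mem(seg1, None, mem_len=3)             # keeps [[1.], [2.], [3.]]
mem = cache_mem(seg2, mem, mem_len=3)              # keeps [[5.], [6.], [7.]]
```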
Model Code
from dataclasses import dataclass
from typing import Optional, List, Tuple
import tensorflow as tf
from tensorflow.keras import layers
from transformers import shape_list
from transformers.activations_tf import get_tf_activation
from transformers.modeling_tf_utils import get_initializer
from transformers.tf_utils import stable_softmax
from transformers.utils import ModelOutput
class TFXLNetModel(tf.keras.Model):
def __init__(self, config, *inputs, **kwargs):
        super().__init__(*inputs, **kwargs)
self.xlnet = TFXLNetMainLayer(config, name="xlnet")
def call(
self,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
token_type_ids=None,
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_mems=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False
):
outputs = self.xlnet(
input_ids=input_ids,
attention_mask=attention_mask,
mems=mems,
perm_mask=perm_mask,
target_mapping=target_mapping,
token_type_ids=token_type_ids,
input_mask=input_mask,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
use_mems=use_mems,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
training=training
)
return outputs
class TFXLNetMainLayer(layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.config = config
self.output_hidden_states = config.output_hidden_states
self.output_attentions = config.output_attentions
self.return_dict = config.return_dict
self.mem_len = config.mem_len
self.reuse_len = config.reuse_len
self.d_model = config.d_model
self.same_length = config.same_length
self.attn_type = config.attn_type
self.bi_data = config.bi_data
self.clamp_len = config.clamp_len
self.n_layer = config.n_layer
self.use_bfloat16 = config.use_bfloat16
self.initializer_range = config.initializer_range
self.word_embedding = TFSharedEmbeddings(
config.vocab_size,
config.d_model,
initializer_range=config.initializer_range,
name="word_embedding"
)
self.layers = [TFXLNetLayer(config, name=f"layer_._{i}") for i in range(config.n_layer)]
self.dropout = layers.Dropout(config.dropout)
self.use_mems_eval = config.use_mems_eval
self.use_mems_train = config.use_mems_train
def build(self, input_shape):
initializer = get_initializer(self.initializer_range)
self.mask_emb = self.add_weight(
shape=(1, 1, self.d_model),
initializer=initializer,
trainable=True,
name="mask_emb"
        )
        super().build(input_shape)
def create_mask(self, qlen, mlen):
"""
same_length=False: same_length=True:
<mlen > < qlen > <mlen > < qlen >
^ [0 0 0 0 0 1 1 1 1] [0 0 0 0 0 1 1 1 1]
[0 0 0 0 0 0 1 1 1] [1 0 0 0 0 0 1 1 1]
qlen [0 0 0 0 0 0 0 1 1] [1 1 0 0 0 0 0 1 1]
[0 0 0 0 0 0 0 0 1] [1 1 1 0 0 0 0 0 1]
v [0 0 0 0 0 0 0 0 0] [1 1 1 1 0 0 0 0 0]
"""
attn_mask = tf.ones([qlen, qlen])
mask_u = tf.linalg.band_part(attn_mask, 0, -1)
mask_dia = tf.linalg.band_part(attn_mask, 0, 0)
attn_mask_pad = tf.zeros([qlen, mlen])
ret = tf.concat([attn_mask_pad, mask_u - mask_dia], axis=1)
if self.same_length:
mask_l = tf.linalg.band_part(attn_mask, -1, 0)
ret = tf.concat([ret[:, : qlen] + mask_l - mask_dia, ret[:, qlen:]], axis=1)
return ret
def relative_positional_encoding(self, qlen, klen, bsz=None):
freq_seq = tf.range(0, self.d_model, 2.0)
inv_freq = 1 / (10000 ** (freq_seq / self.d_model))
if self.attn_type == "bi":
beg, end = klen, -qlen
elif self.attn_type == "uni":
beg, end = klen, -1
else:
raise ValueError(f"Unknown `attn_type` {self.attn_type}.")
if self.bi_data:
fwd_pos_seq = tf.range(beg, end, -1.0)
bwd_pos_seq = tf.range(-beg, -end, 1.0)
if self.clamp_len > 0:
fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
bwd_pos_seq = tf.clip_by_value(bwd_pos_seq, -self.clamp_len, self.clamp_len)
if bsz is not None:
if bsz % 2 != 0:
raise ValueError(f"With bi_data, the batch size {bsz} should be divisible by 2")
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz // 2)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq, bsz // 2)
else:
fwd_pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq)
bwd_pos_emb = self.positional_embedding(bwd_pos_seq, inv_freq)
pos_emb = tf.concat([fwd_pos_emb, bwd_pos_emb], axis=1)
else:
fwd_pos_seq = tf.range(beg, end, -1.0)
if self.clamp_len > 0:
fwd_pos_seq = tf.clip_by_value(fwd_pos_seq, -self.clamp_len, self.clamp_len)
pos_emb = self.positional_embedding(fwd_pos_seq, inv_freq, bsz)
return pos_emb
@staticmethod
def positional_embedding(pos_seq, inv_freq, bsz=None):
sinusoid_inp = tf.einsum("i,d->id", pos_seq, inv_freq)
pos_emb = tf.concat([tf.sin(sinusoid_inp), tf.cos(sinusoid_inp)], axis=-1)
pos_emb = pos_emb[:, None, :]
if bsz is not None:
pos_emb = tf.tile(pos_emb, [1, bsz, 1])
return pos_emb
def cache_mem(self, curr_out, prev_mem):
if self.reuse_len is not None and self.reuse_len > 0:
curr_out = curr_out[: self.reuse_len]
if self.mem_len is None or self.mem_len == 0:
cutoff = 0
else:
cutoff = -self.mem_len
if prev_mem is None:
new_mem = curr_out[cutoff:]
else:
new_mem = tf.concat([prev_mem, curr_out], 0)[cutoff: ]
return tf.stop_gradient(new_mem)
def call(
self,
input_ids=None,
attention_mask=None,
mems=None,
perm_mask=None,
target_mapping=None,
token_type_ids=None,
input_mask=None,
head_mask=None,
inputs_embeds=None,
use_mems=None,
output_attentions=None,
output_hidden_states=None,
return_dict=None,
training=False,
):
if training and use_mems is None:
            use_mems = self.use_mems_train
else:
use_mems = self.use_mems_eval
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
input_ids = tf.transpose(input_ids, perm=(1, 0))
qlen, bsz = shape_list(input_ids)[: 2]
elif inputs_embeds is not None:
inputs_embeds = tf.transpose(inputs_embeds, perm=(1, 0, 2))
qlen, bsz = shape_list(inputs_embeds)[: 2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")
token_type_ids = tf.transpose(token_type_ids, perm=(1, 0)) if token_type_ids is not None else None
        # the 0/1 convention of input_mask is the opposite of attention_mask:
        # in input_mask, 0 means not masked and 1 means masked
        input_mask = tf.transpose(input_mask, perm=(1, 0)) if input_mask is not None else None
        # in attention_mask, 1 means not masked and 0 means masked
        attention_mask = tf.transpose(attention_mask, perm=(1, 0)) if attention_mask is not None else None
        # XLNet implements the random permutation through perm_mask
        # perm_mask[k, i, j] = 1 means that in batch k, token i cannot attend to token j
        # e.g. the permutation 3, 2, 4, 1 of the sequence 1, 2, 3, 4 corresponds to perm_mask
        # [[1 0 0 0]
        #  [1 1 0 1]
        #  [1 1 1 1]
        #  [1 0 0 1]]
        # row 3 is all ones: token 3 comes first in the permutation, so it can attend to nothing
        # in row 4, columns 1 and 4 are 1: token 4 cannot see token 1 or itself, but can see 2 and 3
perm_mask = tf.transpose(perm_mask, perm=(1, 2, 0)) if perm_mask is not None else None
        # target_mapping[k, i, j] = 1 means that in batch k the i-th prediction target
        # sits at position j of the sequence
        # used during pre-training; set it to None for downstream tasks
target_mapping = tf.transpose(target_mapping, perm=(1, 2, 0)) if target_mapping is not None else None
mlen = shape_list(mems[0])[0] if mems is not None and mems[0] is not None else 0
klen = mlen + qlen
if self.attn_type == "uni":
attn_mask = self.create_mask(qlen, mlen)
attn_mask = attn_mask[:, :, None, None]
elif self.attn_type == "bi":
attn_mask = None
else:
raise ValueError(f"Unsupported attention type: {self.attn_type}")
if input_mask is None and attention_mask is not None:
one_cst = tf.constant(1.0)
input_mask = 1.0 - tf.cast(attention_mask, dtype=one_cst.dtype)
if input_mask is not None and perm_mask is not None:
data_mask = input_mask[None, :, :] + perm_mask
elif input_mask is not None and perm_mask is None:
data_mask = input_mask[None, :, :]
elif input_mask is None and perm_mask is not None:
data_mask = perm_mask
else:
data_mask = None
if data_mask is not None:
if mlen > 0:
mems_mask = tf.zeros([shape_list(data_mask)[0], mlen, bsz])
data_mask = tf.concat([mems_mask, data_mask], axis=1)
if attn_mask is None:
attn_mask = data_mask[:, :, :, None]
else:
attn_mask += data_mask[:, :, :, None]
if attn_mask is not None:
attn_mask = tf.cast(attn_mask > 0, dtype=attn_mask.dtype)
if attn_mask is not None:
            # non_tgt_mask differs from attn_mask on the diagonal, where 1 becomes 0
            # i.e. with non_tgt_mask, each token can see itself
            # non_tgt_mask is used to compute the content stream
non_tgt_mask = -tf.eye(qlen)
if mlen > 0:
                non_tgt_mask = tf.concat([tf.zeros([qlen, mlen]), non_tgt_mask], axis=-1)
non_tgt_mask = tf.cast((attn_mask + non_tgt_mask[:, :, None, None]) > 0, dtype=non_tgt_mask.dtype)
else:
non_tgt_mask = None
# Word embedding
if inputs_embeds is not None:
word_emb_k = inputs_embeds
else:
word_emb_k = self.word_embedding(input_ids)
        # output_h is the content stream, initialized from the input word embeddings
        # shape: (qlen, bsz, d_model)
        output_h = self.dropout(word_emb_k, training=training)
        if target_mapping is not None:
            word_emb_q = tf.tile(self.mask_emb, [shape_list(target_mapping)[0], bsz, 1])
            # output_g is the query stream, initialized from the trainable mask embedding
            output_g = self.dropout(word_emb_q, training=training)
else:
output_g = None
# Segment embedding
if token_type_ids is not None:
if mlen > 0:
                mem_pad = tf.zeros([mlen, bsz], dtype=token_type_ids.dtype)
cat_ids = tf.concat([mem_pad, token_type_ids], 0)
else:
cat_ids = token_type_ids
            # seg_mat[i, j] = 1 means the token at position i and the token at position j
            # belong to different segments
            # shape: (qlen, klen, bsz)
seg_mat = tf.cast(
tf.logical_not(tf.equal(token_type_ids[:, None], cat_ids[None, :])),
dtype=token_type_ids.dtype
)
            # convert seg_mat to one-hot form
            # shape: (qlen, klen, bsz, 2)
seg_mat = tf.one_hot(seg_mat, 2)
else:
seg_mat = None
# Position embedding
pos_emb = self.relative_positional_encoding(qlen, klen, bsz=bsz)
pos_emb = self.dropout(pos_emb, training=training)
if head_mask is not None:
raise NotImplementedError
else:
head_mask = [None] * self.n_layer
new_mems = ()
if mems is None:
mems = [None] * len(self.layers)
attentions = [] if output_attentions else None
hidden_states = [] if output_hidden_states else None
for i, layer_module in enumerate(self.layers):
if use_mems:
                new_mems = new_mems + (self.cache_mem(output_h, mems[i]),)
if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
outputs = layer_module(
output_h,
output_g,
non_tgt_mask,
attn_mask,
pos_emb,
seg_mat,
mems[i],
target_mapping,
head_mask[i],
output_attentions,
training=training
)
output_h, output_g = outputs[: 2]
if output_attentions:
attentions.append(outputs[2])
if output_hidden_states:
hidden_states.append((output_h, output_g) if output_g is not None else output_h)
output = self.dropout(output_g if output_g is not None else output_h, training=training)
output = tf.transpose(output, perm=(1, 0, 2))
if not use_mems:
new_mems = None
if output_hidden_states:
if output_g is not None:
hidden_states = tuple(tf.transpose(h, perm=(1, 0, 2)) for hs in hidden_states for h in hs)
else:
hidden_states = tuple(tf.transpose(hs, perm=(1, 0, 2)) for hs in hidden_states)
if output_attentions:
if target_mapping is not None:
attentions = tuple(
tuple(tf.transpose(attn_stream, perm=(2, 3, 0, 1)) for attn_stream in t) for t in attentions
)
else:
attentions = tuple(tf.transpose(t, perm=(2, 3, 0, 1)) for t in attentions)
if not return_dict:
return tuple(v for v in [output, new_mems, hidden_states, attentions] if v is not None)
return TFXLNetModelOutput(
last_hidden_state=output, mems=new_mems, hidden_states=hidden_states, attentions=attentions
)
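The docstring diagram in `create_mask` can be reproduced with a numpy equivalent of the `tf.linalg.band_part` calls, shown here for qlen=4, mlen=2 and same_length=False (1 = masked future position), as a quick sanity check:

```python
import numpy as np

qlen, mlen = 4, 2
ones = np.ones((qlen, qlen))
mask_u = np.triu(ones)        # upper triangle incl. diagonal, like band_part(a, 0, -1)
mask_dia = np.eye(qlen)       # diagonal only, like band_part(a, 0, 0)
pad = np.zeros((qlen, mlen))  # memory positions are always visible
ret = np.concatenate([pad, mask_u - mask_dia], axis=1)
print(ret.astype(int))
# [[0 0 0 1 1 1]
#  [0 0 0 0 1 1]
#  [0 0 0 0 0 1]
#  [0 0 0 0 0 0]]
```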
class TFSharedEmbeddings(layers.Layer):
def __init__(self, vocab_size, hidden_size, initializer_range, **kwargs):
super().__init__(**kwargs)
self.vocab_size = vocab_size
self.hidden_size = hidden_size
self.initializer_range = hidden_size ** -0.5 if initializer_range is None else initializer_range
def build(self, input_shape):
self.weight = self.add_weight(
"weight", shape=[self.vocab_size, self.hidden_size], initializer=get_initializer(self.initializer_range)
)
super().build(input_shape)
def call(self, inputs, mode="embedding"):
if mode == "embedding":
return self._embedding(inputs)
elif mode == "linear":
return self._linear(inputs)
else:
raise ValueError(f"mode {mode} is not valid.")
def _embedding(self, input_ids):
return tf.gather(self.weight, input_ids)
    def _linear(self, inputs):
        # project hidden states back to vocabulary logits with the tied embedding matrix
        first_dims = shape_list(inputs)[:-1]
        x = tf.reshape(inputs, [-1, self.hidden_size])
        logits = tf.matmul(x, self.weight, transpose_b=True)
        return tf.reshape(logits, first_dims + [self.vocab_size])
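`_linear` reuses the embedding matrix as the output projection (weight tying): logits are just the hidden states multiplied by the transposed embedding table. A numpy sketch with toy shapes:

```python
import numpy as np

vocab_size, hidden_size = 5, 3
weight = np.random.default_rng(0).normal(size=(vocab_size, hidden_size))

h = np.random.default_rng(1).normal(size=(2, 4, hidden_size))  # (bsz, seq, hidden)
logits = h.reshape(-1, hidden_size) @ weight.T                 # tied projection
logits = logits.reshape(2, 4, vocab_size)
print(logits.shape)
```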
class TFXLNetLayer(layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.rel_attn = TFXLNetRelativeAttention(config, name="rel_attn")
self.ff = TFXLNetFeedForward(config, name="ff")
self.dropout = layers.Dropout(config.dropout)
def call(
self,
output_h,
output_g,
non_tgt_mask,
attn_mask,
pos_emb,
seg_mat,
mems,
target_mapping,
head_mask,
output_attentions,
training=False
):
outputs = self.rel_attn(
output_h,
output_g,
non_tgt_mask,
attn_mask,
pos_emb,
seg_mat,
mems,
target_mapping,
head_mask,
output_attentions,
training=training
)
output_h, output_g = outputs[: 2]
if output_g is not None:
output_g = self.ff(output_g, training=training)
output_h = self.ff(output_h, training=training)
outputs = (output_h, output_g) + outputs[2:]
return outputs
class TFXLNetRelativeAttention(layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
if config.d_model % config.n_head != 0:
raise ValueError(
f"The hidden size ({config.d_model}) is not a multiple of the number of attention "
                f"heads ({config.n_head})"
)
self.n_head = config.n_head
self.d_head = config.d_head
self.d_model = config.d_model
self.scale = 1 / (config.d_head ** 0.5)
self.initializer_range = config.initializer_range
self.output_attentions = config.output_attentions
self.layer_norm = layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.dropout = layers.Dropout(config.dropout)
def rel_attn_core(
self,
q_head,
k_head_h,
v_head_h,
k_head_r,
seg_mat,
attn_mask,
head_mask,
output_attentions,
training=False
):
ac = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_w_bias, k_head_h)
bd = tf.einsum("ibnd,jbnd->ijbn", q_head + self.r_r_bias, k_head_r)
        bd = self.rel_shift(bd, klen=shape_list(ac)[1])
if seg_mat is None:
ef = 0
else:
ef = tf.einsum("ibnd,snd->ibns", q_head + self.r_s_bias, self.seg_embed)
ef = tf.einsum("ijbs,ibns->ijbn", seg_mat, ef)
attn_score = (ac + bd + ef) * self.scale
if attn_mask is not None:
if attn_mask.dtype == tf.float16 or attn_mask.dtype == tf.bfloat16:
attn_score = attn_score - 65500 * attn_mask
else:
attn_score = attn_score - 1e30 * attn_mask
attn_prob = stable_softmax(attn_score, axis=1)
attn_prob = self.dropout(attn_prob, training=training)
if head_mask is not None:
attn_prob = attn_prob * head_mask
attn_vec = tf.einsum("ijbn,jbnd->ibnd", attn_prob, v_head_h)
if output_attentions:
return attn_vec, attn_prob
return attn_vec
def rel_shift(self, x, klen=-1):
x_size = shape_list(x)
x = tf.reshape(x, (x_size[1], x_size[0], x_size[2], x_size[3]))
x = x[1:, ...]
x = tf.reshape(x, (x_size[0], x_size[1] - 1, x_size[2], x_size[3]))
x = x[:, 0:klen, :, :]
return x
def build(self, input_shape):
initializer = get_initializer(self.initializer_range)
self.q = self.add_weight(
shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="q"
)
self.k = self.add_weight(
shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="k"
)
self.v = self.add_weight(
shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="v"
)
self.o = self.add_weight(
shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="o"
)
self.r = self.add_weight(
shape=(self.d_model, self.n_head, self.d_head), initializer=initializer, trainable=True, name="r"
)
self.r_r_bias = self.add_weight(
shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_r_bias"
)
self.r_s_bias = self.add_weight(
shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_s_bias"
)
self.r_w_bias = self.add_weight(
shape=(self.n_head, self.d_head), initializer="zeros", trainable=True, name="r_w_bias"
)
self.seg_embed = self.add_weight(
shape=(2, self.n_head, self.d_head), initializer=initializer, trainable=True, name="seg_embed"
)
super().build(input_shape)
    def post_attention(self, h, attn_vec, residual=True, training=False):
        # shape: (..., n_head, d_head) -> (..., d_model)
        attn_out = tf.einsum("ibnd,hnd->ibh", attn_vec, self.o)
        attn_out = self.dropout(attn_out, training=training)
        # residual connection
        if residual:
            attn_out = attn_out + h
        output = self.layer_norm(attn_out)
        return output
def call(
self,
h,
g,
attn_mask_h,
attn_mask_g,
r,
seg_mat,
mems,
target_mapping,
head_mask,
output_attentions,
training=False
):
if g is not None:
if mems is not None and len(shape_list(mems)) > 1:
# shape: (mlen+qlen, bsz, d_model)
cat = tf.concat([mems, h], axis=0)
else:
# shape: (qlen, bsz, d_model)
cat = h
# h stream
# [qlen, bsz, n_head, d_head]
q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)
k_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.k)
v_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.v)
k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)
attn_vec_h = self.rel_attn_core(
q_head_h,
k_head_h,
v_head_h,
k_head_r,
seg_mat,
attn_mask_h,
head_mask,
output_attentions,
training=training
)
if output_attentions:
attn_vec_h, attn_prob_h = attn_vec_h
output_h = self.post_attention(h, attn_vec_h, training=training)
# g stream
q_head_g = tf.einsum("ibh,hnd->ibnd", g, self.q)
            # target_mapping
            # shape: (num_predict, qlen, bsz)
            # num_predict is the number of prediction targets (with partial prediction, num_predict < qlen)
if target_mapping is not None:
q_head_g = tf.einsum("mbnd,mlb->lbnd", q_head_g, target_mapping)
attn_vec_g = self.rel_attn_core(
q_head_g,
k_head_h,
v_head_h,
k_head_r,
seg_mat,
attn_mask_g,
head_mask,
output_attentions,
training=training
)
if output_attentions:
attn_vec_g, attn_prob_g = attn_vec_g
attn_vec_g = tf.einsum("lbnd,mlb->mbnd", attn_vec_g, target_mapping)
else:
attn_vec_g = self.rel_attn_core(
q_head_g,
k_head_h,
v_head_h,
k_head_r,
seg_mat,
attn_mask_g,
head_mask,
output_attentions,
training=training
)
if output_attentions:
attn_vec_g, attn_prob_g = attn_vec_g
output_g = self.post_attention(g, attn_vec_g, training=training)
if output_attentions:
attn_prob = attn_prob_h, attn_prob_g
else:
if mems is not None and len(shape_list(mems)) > 1:
cat = tf.concat([mems, h], axis=0)
else:
cat = h
q_head_h = tf.einsum("ibh,hnd->ibnd", h, self.q)
k_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.k)
v_head_h = tf.einsum("ibh,hnd->ibnd", cat, self.v)
k_head_r = tf.einsum("ibh,hnd->ibnd", r, self.r)
attn_vec = self.rel_attn_core(
q_head_h,
k_head_h,
v_head_h,
k_head_r,
seg_mat,
attn_mask_h,
head_mask,
output_attentions,
training=training
)
if output_attentions:
attn_vec, attn_prob = attn_vec
output_h = self.post_attention(h, attn_vec, training=training)
output_g = None
outputs = (output_h, output_g)
if output_attentions:
outputs = outputs + (attn_prob, )
return outputs
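`rel_shift` is the Transformer-XL relative-shift trick: the bd term is first computed against a full strip of relative position embeddings and then realigned so that entry (i, j) ends up holding the score for relative distance i - j. A 2-D numpy sketch (batch and head axes dropped), assuming qlen = 2 and klen = 2 so that the position strip covers the distances [2, 1, 0, -1]:

```python
import numpy as np

def rel_shift_np(x, klen):
    q, r = x.shape
    x = x.reshape(r, q)        # swap axes via reshape, not transpose
    x = x[1:, :]               # drop the first "row"
    x = x.reshape(q, r - 1)    # reshape back: every row is now shifted
    return x[:, :klen]

# tag score x[i, j] with the relative distance of position-embedding column j
x = np.tile(np.array([2, 1, 0, -1]), (2, 1))
print(rel_shift_np(x, klen=2))
# [[ 0 -1]
#  [ 1  0]]   i.e. entry (i, j) now holds distance i - j
```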
class TFXLNetFeedForward(layers.Layer):
def __init__(self, config, **kwargs):
super().__init__(**kwargs)
self.layer_norm = tf.keras.layers.LayerNormalization(epsilon=config.layer_norm_eps, name="layer_norm")
self.layer_1 = tf.keras.layers.Dense(
config.d_inner, kernel_initializer=get_initializer(config.initializer_range), name="layer_1"
)
self.layer_2 = tf.keras.layers.Dense(
config.d_model, kernel_initializer=get_initializer(config.initializer_range), name="layer_2"
)
self.dropout = tf.keras.layers.Dropout(config.dropout)
if isinstance(config.ff_activation, str):
self.activation_function = get_tf_activation(config.ff_activation)
else:
self.activation_function = config.ff_activation
def call(self, inp, training=False):
output = inp
output = self.layer_1(output)
output = self.activation_function(output)
output = self.dropout(output, training=training)
output = self.layer_2(output)
output = self.dropout(output, training=training)
output = self.layer_norm(output + inp)
return output
@dataclass
class TFXLNetModelOutput(ModelOutput):
last_hidden_state: tf.Tensor = None
mems: Optional[List[tf.Tensor]] = None
hidden_states: Optional[Tuple[tf.Tensor]] = None
attentions: Optional[Tuple[tf.Tensor]] = None