系统学习CV-Transformer
- 参考
- 概念
参考
https://www.bilibili.com/video/BV15v411W78M?spm_id_from=333.999.0.0&vd_source=7155082256127a432d5ed516a6423e20 https://www.bilibili.com/video/BV1pu411o7BE?spm_id_from=333.337.search-card.all.click&vd_source=7155082256127a432d5ed516a6423e20
– 未扒完 https://zh-v2.d2l.ai/chapter_attention-mechanisms/multihead-attention.html
概念
:自主,非自主 https://zh-v2.d2l.ai/chapter_attention-mechanisms/attention-cues.html
非自主性:全连接层和汇聚层
- embedding: 输入通过编码获得向量
- query: 我检查别人,我提供的向量
- key: 我被别人检查,我提供的向量
- value:表达当前词的特征
矩阵是通过训练获得的 x 1 ? W q = q 1 x1 * W^{q}=q_{1} x1?Wq=q1 x 1 ? W k = k 1 x1 * W^{k}=k_{1} x1?Wk=k1 x 1 ∗ W v = v 1 x1 * W^{v}=v_{1} x1∗Wv=v1 x 2 ∗ W q = q 2 x2 * W^{q}=q_{2} x2∗Wq=q2 x 2 ∗ W k = k 2 x2 * W^{k}=k_{2} x2∗Wk=k2 x 2 ∗ W v = v 2 x2 * W^{v}=v_{2} x2∗Wv=v2
< q 1 ⃗ , q 1 ⃗ > <\vec{q_{1}},\vec{q_{1}}> <q1 ,q1 >, < q 1 ⃗ , q 2 ⃗ > <\vec{q_{1}},\vec{q_{2}}> <q1 ,q2 > 得到每个词与其他词的关系()
w 11 = < q 1 ⃗ , q 1 ⃗ > w_{11}=<\vec{q_{1}},\vec{q_{1}}> w11=<q1 ,q1 >, w 12 = < q 1 ⃗ , q 2 ⃗ > w_{12}=<\vec{q_{1}},\vec{q_{2}}> w12=<q1 ,q2 > o u t ( x 1 ) = w 11 ∗ v 1 + w 12 ∗ v 2 out(x_{1})=w_{11}*v_{1}+w_{12}*v_{2} out(x1)=w11∗v1+w12∗v2 再求加权求和进行词的重构(注意权重归一化)
w 21 = < q 2 ⃗ , q 1 ⃗ > w_{21}=<\vec{q_{2}},\vec{q_{1}}> w21=<q2 ,q1 >, w 22 = < q 2 ⃗ , q 2 ⃗ > w_{22}=<\vec{q_{2}},\vec{q_{2}}> w22=<q2 ,q2 > o u t ( x 2 ) = w 21 ∗ v 1 + w 22 ∗ v 2 out(x_{2})=w_{21}*v_{1}+w_{22}*v_{2} out(x2)=w21∗v1+w22∗v2 再求加权求和进行词的重构(注意权重归一化)
–Nadaraya-Watson核回归 https://www.bilibili.com/video/BV1264y1i7R1?spm_id_from=333.999.0.0&vd_source=7155082256127a432d5ed516a6423e20
将查询与键之间的关系(注意力权重)建模为高斯核函数 最终其实就是softmax核函数
https://www.bilibili.com/video/BV1264y1i7R1?spm_id_from=333.999.0.0&vd_source=7155082256127a432d5ed516a6423e20
上述的内积是一种评分函数 更通用的写法是 a ( q , k i ) a(q,k_{i}) a(q,ki) 两种常用的注意力评分函数:、
query和key长度不同
# https://zh-v2.d2l.ai/chapter_attention-mechanisms/attention-scoring-functions.html
class AdditiveAttention(nn.Module):
def __init__(self,key_size,query_size,num_hiddens,dropout,**kwargs):
super(AdditiveAttention,self).__init__(**kwargs)
self.w_k=nn.Linear(key_size,num_hiddens,bias=False)
self.w_q=nn.Linear(query_size,num_hiddens,bias=False)
self.w_v=nn.Linear(num_hiddens,1,bias=False)
self.dropout=nn.Dropout(dropout)
def forward(self,queries,keys,values,valid_lens):
queries=self.w_q(queries) # (batch_size,查询的个数,num_hiddens)
keys=self.w_k(keys) # (batch_size,键值对的个数,num_hiddens)
# (batch_size,查询的个数,1,num_hiddens)
# (batch_size,1,键值对的个数,num_hiddens)
features=quires,unsqueeze(2)+keys.unsqueeze(1) #
features=torch.tanh(features)
scores=self.w_v(features).squeeze(-1)
# 做softmax
self.attention_weight=masked_softmax(scores,valid_lens)
return torch.bmm(self.dropout(self.attention_weights),values)
query 和 key长度相同 注意:
class DotProductAttention(nn.Module):
def __init___(self,dropout,**kwargs):
super(DotProductAttention,self).__init__(**kwargs)
self.dropout=nn.Dropout(dropout)
def forward(self,queries,keys,values,vaild_lens=None):
d=queries.shape[-1]
scores=torch.bmm(queries,keys.transpose(1,2)/math.sqrt(d))
self.attention_weight=masked_softmax(scores,valid_lens)
return torch.bmm
标签: 5w12v直插二极管sick小型光电传感器w12g