打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
Transformer语言模型原理解读
userphoto

2023.02.25 湖北

关注

一、简介

基于假设:一个词在句子中的意思,与上下文(语境)有关。与哪些词有关呢?Transformer就是:利用点积将句子中所有词的影响当成权重都考虑了进去。

Transform模型是与RNN和CNN都完全不同的思路。相比Transformer,RNN/CNN的问题:

  1. RNN序列化处理效率提不上去。理论上,RNN效果上问题不大。
  2. CNN感受野小。CNN只考虑卷积核大小区域,核内参数共享,并行/计算效率不是问题,但受限于核的大小,不能考虑整个上下文。

在并行方面,多头attention和CNN一样不依赖于前一时刻的计算,可以很好的并行,优于RNN。在长距离依赖上,由于self-attention是每个词和所有词都要计算attention,所以不管他们中间有多长距离,最大的路径长度也都只是1。可以捕获长距离依赖关系。

二、注意力机制

注意力实际就是加权

2.1 NLP中的注意力

以RNN做机器翻译为例,下两图[1]分别是有没有注意力:


没有注意力机制的机器翻译,翻译下一词时,只考虑源语言经过网络后最终的表达(编码/向量);而注意力机制是要考虑源语言中每(多)个词的表达(编码/向量)。

NLP中有个非常常见的一个三元组概念:Query、Key、Value,其中绝大部分情况Key=Value。在机器翻译中,Query是已经翻译出来的部分,Key和Value是源语言中每个词的表达(编码/向量),没有注意力时直接拿Query就去预测下一个词,注意力机制的计算就是用Query和Key计算出一组权重,赋权到Value上,拿Value去预测下一词。

翻译编码解码模型[2]

计算权重[2]

加权[2]

2.2 自注意力

自注意力模型就是Query“=”Key“=”Value,挖掘一个句子内部的联系。计算句子中每个字之间的互相影响/权重,再加权到句子中每个字的向量上。这个计算就是用了点积。

Query、Key、Value都来自同一个输入,但是经过3个不同线性映射(全连接层)得到,所以未必完全相等。


公式中QKT是Query向量和Key向量做点积,为了防止点积结果数值过大,做了一个放缩(dk是Key向量的长度),结果再经过一个softmax归一化成一个和为1的权重,乘到Value向量上。

attention可视化的效果(这里不同颜色代表attention不同头的结果,颜色越深attention值越大)。可以看到self-attention在这里可以学习到句子内部长距离依赖'making…….more difficult'这个短语。

2.2.1 点积(Dot-Product)

  • 两向量点积表示两个向量的相似度。
  • 点积还有一个重要的特点是没有参数。

点积也叫点乘,一维点积用几何表示是: ab=|a||b|cosθ 。与我们常用的余弦相识度/夹角作用一样,与两向量的相似程度成正比。

2.2.2 具体计算过程:

假设我们句子长度设为512,每个单词embedding成256维。

  1. QKTQuery与Key点积。

Pytorch代码:

attn = torch.bmm(q, k.transpose(1, 2))
  1. scale放缩、softmax归一化、dropout随机失活/置零
    Pytorch代码:
attn = attn / self.temperature
if mask is not None:
    attn = attn.masked_fill(mask, -np.inf)
attn = self.softmax(attn)
attn = self.dropout(attn)
  1. 将权重矩阵加权到Value上,维度未变化。

Pytorch代码:

output = torch.bmm(attn, v)

2.3 多头注意力

并不是将长度是512的句子整个做点积自注意力,而是将其“拆”成h份,没份长度为512/h,然后每份单独去加权注意力再拼接到一起,Q、K、V分别拆分。

“拆”的过程是一个独立的(different)、可学习的(learned)线性映射。实际实现可以是h个全连接层,每个全连接层输入维度是512,输出512/h;也可以用一个全连接,输入输出均为512,输出之后再切成h份。

多头能够从不同的表示子空间里学习相关信息。

在两个头和单头的比较中,可以看到单头'its'这个词只能学习到'law'的依赖关系,而两个头'its'不仅学习到了'law'还学习到了'application'依赖关系。

Pytorch实现:

    def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
        self.w_qs = nn.Linear(d_model, n_head * d_k)
        self.w_ks = nn.Linear(d_model, n_head * d_k)
        self.w_vs = nn.Linear(d_model, n_head * d_v)
        nn.init.normal_(self.w_qs.weight, mean=0, std=np.sqrt(2.0 / (d_model   d_k)))
        nn.init.normal_(self.w_ks.weight, mean=0, std=np.sqrt(2.0 / (d_model   d_k)))
        nn.init.normal_(self.w_vs.weight, mean=0, std=np.sqrt(2.0 / (d_model   d_v)))  
        ...
    def forward(self, q, k, v, mask=None):

        sz_b, len_q, _ = q.size()
        sz_b, len_k, _ = k.size()
        sz_b, len_v, _ = v.size()

        q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
        k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
        v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)

        q = q.permute(2, 0, 1, 3).contiguous().view(-1, len_q, d_k) # (n*b) x lq x dk
        k = k.permute(2, 0, 1, 3).contiguous().view(-1, len_k, d_k) # (n*b) x lk x dk
        v = v.permute(2, 0, 1, 3).contiguous().view(-1, len_v, d_v) # (n*b) x lv x dv

三、位置编码(Positional Encoding)

因为transformer没有RNN和CNN,为了考虑位置信息,论文中直接将全局位置编号加到Embedding向量每个维度上。
Pytorch代码:

        # -- Forward
        enc_output = self.src_word_emb(src_seq)   self.position_enc(src_pos)

另外,论文中位置编码还利用了sin/cos正余弦函数考虑周期性和归一化。

四、残差和前馈(Feed Forward)

4.1 为什么残差[3]

网络的深度为什么重要?

因为CNN能够提取low/mid/high-level的特征,网络的层数越多,意味着能够提取到不同level的特征越丰富。并且,越深的网络提取的特征越抽象,越具有语义信息。

为什么不能简单地增加网络层数?

对于原来的网络,如果简单地增加深度,会导致梯度弥散或梯度爆炸。

对于该问题的解决方法是正则化初始化和中间的正则化层(Batch Normalization),这样的话可以训练几十层的网络。

虽然通过上述方法能够训练了,但是又会出现另一个问题,就是退化问题,网络层数增加,但是在训练集上的准确率却饱和甚至下降了。这个不能解释为overfitting,因为overfit应该表现为在训练集上表现更好才对。
退化问题说明了深度网络不能很简单地被很好地优化。
作者通过实验:通过浅层网络 y=x 等同映射构造深层模型,结果深层模型并没有比浅层网络有等同或更低的错误率,推断退化问题可能是因为深层的网络并不是那么好训练,也就是求解器很难去利用多层网络拟合同等函数。

怎么解决退化问题?

深度残差网络。如果深层网络的后面那些层是恒等映射,那么模型就退化为一个浅层网络。那现在要解决的就是学习恒等映射函数了。 但是直接让一些层去拟合一个潜在的恒等映射函数H(x) = x,比较困难,这可能就是深层网络难以训练的原因。但是,如果把网络设计为H(x) = F(x) x,如下图。我们可以转换为学习一个残差函数F(x) = H(x) - x. 只要F(x)=0,就构成了一个恒等映射H(x) = x. 而且,拟合残差肯定更加容易。

4.2 前馈

每个attention模块后面会跟两个全连接,中间加了一个Relu激活函数,公式表示:

也可用两个核为1的CNN层代替。
两个全连接是512->2048->512的操作。原因未详细介绍。

五、训练-模型的参数在哪里

transformer的核心点积是没有参数,transform结构的训练,会优化的参数主要在:

  1. 嵌入层-Word Embedding
  2. 前馈(Feed Forward)层
  3. 多头注意力中的“切片”操作(映射成多个/头小向量)实际是一个全连接层(线性映射矩阵),以及多头输出拼接结果(Concat)后会经过一个Linear全连接层。这两个全连接层也是残差块有意义的地方,如果没有这一层,那这个注意力机制中就没有参数,残差就没有意义了。

六、参考文献

[1]. Neural Machine Translation by Jointly Learning to Align and Translate
[2].
[3]. (残差的解读)[https://www.cnblogs.com/alanma/p/6877166.html]

本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
【热】打开小程序,算一算2024你的财运
seq2seq翻译模型里的attention model(注意力模型)
如何从零开始用PyTorch实现Chatbot?(附完整代码)
##学习 | 超详细逐步图解 Transformer
以自注意力机制破局Transformer
Transformer算法完全解读
深度学习中的注意力机制
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服