打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
Multi-Head-Attention的作用

导语

      如果有人问你,Multi-Head-Attention的作用是什么?这个八股文一般的问题相信大家也都司空见惯了,《Attention Is All You Need》这篇文章中作者解释的原话是:将隐状态向量分成多个头,形成多个子语义空间,可以让模型去关注不同维度语义空间的信息。不过真的是这样的吗?如果是,这些子语义空间、不同维度的语义空间信息到底指的是什么?

       本文欲对工作、学术中有可能出现的一些Multi-Head-Attention的疑问进行探讨,尽可能的用通俗 的语言和可视化的方法展现出Multi-Head-Attention的内部运作逻辑,涉及问题点:

  1.        如何理解Self-Attention?Attention矩阵怎么读?为什么要scale?为什么要用Self-Attention?(基础知识铺垫)

  2.        Multi-Head-Attention的作用到底是什么?(本文的核心讨论点)

1.Self-Attention

1.1 Attention矩阵应该怎么读?

一些文章里漫天飞舞的Attention矩阵的热力图,看的懵懵的,到底应该怎么理解?比如'王者荣耀怎么玩’这句话某一头的Attention矩阵热力图如下所示,该怎么理解? 

Word2Vec

以ELMo为首的双向LSTM、RNN类语言模型,LSTM拥有长短时记忆门和循环输入序列计算的机制,可以关注到第i个单词前的i-1个单词的信息,对文本进行前向和后向的建模后,将前、后向得到的隐状态embedding拼接起来就可以得到第i个单词的表达,这其实也做到了前后向信息的交互。

多头注意力图解

4条query的第1层attention pattern

对同样的4条query最后一层的多头attention热力图也做了一个输出,如下:

4条query的第4层attention pattern

4条query的12个头的attention pattern

可以看到,第1~6、8、10~12head的pattern在破坏了语法/句法/词法前后都几乎没有任何改变:

4条query的1~6头的attention pattern

只有7、9发生了明显预期的语法/词法层面的注意力改变。词法层面,可以看到拆开了“怎么”、“王者”这样的常见词组,第7、9head词语之间的注意力依然保持原来的相对水准没太大变化:

4条query的第7、9头的attention pattern

句法层面,对于'王者荣耀怎么玩', '怎么玩王荣者耀'这两句互为倒装句的query:

可以看到第7个head,这两句倒装句的倒装主体之内的注意力是差不多的,即红框和绿框中的pattern在倒装后依然后相似的pattern,但是第9个head在倒装后上下黄框和绿框之间的pattern就发生了比较明显的变化。

笔者大量采样query,破坏语法/词法结构后观察多头pattern热力图,发现这个规律非常明显,因此可以说明该BERT底层第7个head记录下了语法和词法层面的pattern,第9个head只有词法层面的pattern。

那对于第1、3、4、5、11、12这些头就可以直接将其剪枝删除掉或者只保留一个吗?很遗憾,笔者目前也没办法给出一个结论,笔者更倾向于它们其实都有作用,只能让下游任务模型自己去甄别这些pattern。目前已经有论文证明在机器翻译等场景下剪枝掉某些head效果会更好:论文地址。

2.2 可疑的大多数

对于大多数没怎么发生变化的pattern,实在太奇怪了。笔者怀疑大部分head只是学习到了position embedding主导的相关信息,即对于任何的query,只要是这个位置,attention pattern是差不多的,果然当笔者把position embedding去掉后,记录了词法、语法层面的第7、9head的pattern几乎没怎么变,而那些之前大部分一成不变的pattern,除了第5、12head外都发生了剧烈的变化,有的甚至开始凸显出了词法信息和语法信息(比如4、6head):

删除position embedding后的attention pattern

2.3 多头注意力初步的结论

这里可以得到一个结论,多头attention的有些头的功能是不一样的,有的头可能没啥信息(如第5head),有的头pattern由位置信息主导,有的头由语法信息主导,有的头由词法信息主导,而能够捕捉到语法/句法/词法信息的头其实是非常少的(这一点已被大量学术论文证明,笔者的句法破坏实验也验证了这一点),但是为了保证这些pattern能够被抽取出来,需要让其有一定的头的基数,因为单头很容易就变成自己注意力全在自己身上了,这一点也可以从'以上pattern中大部分pattern都是自己关注自己’这个现象身上得到佐证。假如将12头强行变成2头,注意力pattern就会变成如下这种自己注意自己的pattern,因为12个头大部分pattern都是自己关注自己,向量拼接起来后token间的向量相似度会被大部分分量所裹挟,因此也变成自己关注自己。

将12个头变成2个头的attention矩阵热力图

从1.4和1.5我们可以知道,bert本身是不期望self-attention收敛到某一种固定pattern上的,它期望可以捕捉到一些多样化的pattern,bert这类语言模型,token的注意力很容易只看到自己,出现只关注自己的pattern,如下图所示:

只注意到自己的单一注意力pattern

而我们更期望还能捕捉出这类多样的pattern出来:

多样的注意力pattern

注意!这里并不是说这种pattern就一定好,也不是说那种位置信息主导的pattern就一定不好,只是说我们期望捕捉更多的模式,从而利于下游多样的任务微调时,一旦这类pattern有用就可以激活出来让下游任务可以学习到,所以Transformer的角色定位是特征抽取器。所以多头对一个向量切分不同的维度来捕捉不同的pattern,这里就可以解释论文里原话中的不同维度的语义信息。

从另一个方面来解释,多头的核心思想就是ensemble,如随机森林一样,将特征切分,每个head就像是一个弱分类器,让最后得到的embedding关注多方面信息,不要过拟合到某一种pattern上,这一点上面的实验图像可以很清晰的看出来。

2.4 多多益善?

那这样来说,头是不是越多越好呢?答案是否定的,已有论文证明,头数不是越多越好,论文实验结果如下:

头数实验

可以看到8/16个头时,PPL/BLEU最好,4/32个头次之,1个头最差。头确实不是越多越好,头太多了,每个qkv分到的维度就会降低,表达能力也就变差,也未必能更好的捕捉到语法/句法/词法信息。

2.5 其它层也有话要说!

为了进一步验证BERT多层、多头之间的变化,笔者对10w条长度相同的query绘制了1~4层的多头Attention矩阵热力图,将其多头Attention矩阵求平均后绘制出如下热力图,从上到下依次是1~4层,从左到右依次是1~12个head: 

大量query下4层的attention pattern均值

这种大量数据矩阵平均的情况,如果某一头的均值attention还能保持一个明显的pattern,那就说明对于任何一个query,该层该头的pattern是差不多的,因此该头的pattern是以位置信息为主导的,几乎不包含语义信息,因为不论啥query,在这个位置的pattern都一个样。反之如果该层该头热力图出现大片同色系区域,则可以一定程度上说明该层该头有可能记录了大量的语义、语法、词法信息,因为这些信息有可能出现在该query的任意位置,因此每个位置平均值大概率就是趋于相同。

可以看到,对于第1、3、4、5、11、12这些头基本pattern是很固定的,而且随着层数的增加,pattern越来越固定,而2、7、8、9、10的pattern在某几层或者全部层都比较多样化(接近于同色系),这也可以说明,为什么需要多层的Transformer堆叠,因为有些信息可能在某一层之中无法捕捉到,需要在其它层捕捉。

最后,基本所有head在最后一层都变得很固定(颜色很深,pattern很明显),随着层数的加深,头和头之间的差异越来越小。

已有论文对层和多头pattern之间的关系做了探讨:论文地址,这篇文章对12层的BERT绘制了层数和头pattern分布的关系,如下图所示,不同的颜色代表不同的层,同一颜色的分布代表了同一层的头差距。比如:

  • (1)我们可以先看看第一层,也就是黑色。在上下左右都有黑色的点出现,分布是比较分散稀疏的。

  • (2)再看看第六层浅紫色的点,相对来说分布比较集中了。

  • (3)再看看第十二层,深红色,基本全部集中在下方,分布非常集中。

由此可以得到这边论文的结论:头之间的差距随着所在层数变大而减少。换句话说,头之间的方差随着所在层数的增大而减小。笔者这里的观察实验也论证了这一点,随着层数的增加,pattern在慢慢固定且趋同。

Attention pattern和层数之间的分布关系

3. 总结

在此,笔者可以得到一些对Multi-Head-Attention的结论:

  • (1)对于大部分query,每个头都学习了某种固定的pattern模式,而且12个头中大部分pattern是差不多的,但是总有少数的pattern才能捕捉到语法/句法/词法信息。

  • (2)越靠近底层的attention,其pattern种类越丰富,关注到的点越多,越到顶层的attention,各个head的pattern趋同。

  • (3)head数越少,pattern会更倾向于token关注自己本身(或者其他的比较单一的模式,比如都关注CLS)。

  • (4)多头的核心思想应该就是ensemble,如随机森林一样,将特征切分,每个head就像是一个弱分类器,让最后得到的embedding关注多方面信息,不要过拟合到某一种pattern上。

  • (5)已有论文证明head数目不是越多越好,bert-base上实验的结果为8、16最好,太多太少都会变差。

  • (6)multi-head-attention中大部分头没有捕捉到语法/句法信息,但是笔者这里没办法做出断言说它们是没有用的,具体还是要看下游任务对其的适配程度。个人倾向于大部分pattern只是不符合人类的语法,在不同的下游任务中应该还是有用武之地的。

———————————————————————————————


【AI科研看爱加】也欢迎各领域科研爱好者或AI爱好者们加入到爱加大家庭中来,共同运营与维护这个公益性AI公众号,让我们站在各研究领域、学习领域或者工作领域,共同讲好AI故事、传播AI声音,受益于更多对AI有兴趣的伙伴们.

本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
Transform模型原理
再谈attention机制
抛开卷积,多头自注意力能够表达任何卷积操作
旷视孙剑团队提出Anchor DETR:基于Transformer的目标检测新网络
Transformer
Lucene学习总结之二:Lucene的总体架构
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服