如果有人问你,Multi-Head-Attention的作用是什么?这个八股文一般的问题相信大家也都司空见惯了,《Attention Is All You Need》这篇文章中作者解释的原话是:将隐状态向量分成多个头,形成多个子语义空间,可以让模型去关注不同维度语义空间的信息。不过真的是这样的吗?如果是,这些子语义空间、不同维度的语义空间信息到底指的是什么?
本文欲对工作、学术中有可能出现的一些Multi-Head-Attention的疑问进行探讨,尽可能的用通俗 的语言和可视化的方法展现出Multi-Head-Attention的内部运作逻辑,涉及问题点:
如何理解Self-Attention?Attention矩阵怎么读?为什么要scale?为什么要用Self-Attention?(基础知识铺垫)
Multi-Head-Attention的作用到底是什么?(本文的核心讨论点)
一些文章里漫天飞舞的Attention矩阵的热力图,看的懵懵的,到底应该怎么理解?比如'王者荣耀怎么玩’这句话某一头的Attention矩阵热力图如下所示,该怎么理解?
以ELMo为首的双向LSTM、RNN类语言模型,LSTM拥有长短时记忆门和循环输入序列计算的机制,可以关注到第i个单词前的i-1个单词的信息,对文本进行前向和后向的建模后,将前、后向得到的隐状态embedding拼接起来就可以得到第i个单词的表达,这其实也做到了前后向信息的交互。
对同样的4条query最后一层的多头attention热力图也做了一个输出,如下:
可以看到,第1~6、8、10~12head的pattern在破坏了语法/句法/词法前后都几乎没有任何改变:
只有7、9发生了明显预期的语法/词法层面的注意力改变。词法层面,可以看到拆开了“怎么”、“王者”这样的常见词组,第7、9head词语之间的注意力依然保持原来的相对水准没太大变化:
句法层面,对于'王者荣耀怎么玩', '怎么玩王荣者耀'这两句互为倒装句的query:
可以看到第7个head,这两句倒装句的倒装主体之内的注意力是差不多的,即红框和绿框中的pattern在倒装后依然后相似的pattern,但是第9个head在倒装后上下黄框和绿框之间的pattern就发生了比较明显的变化。
笔者大量采样query,破坏语法/词法结构后观察多头pattern热力图,发现这个规律非常明显,因此可以说明该BERT底层第7个head记录下了语法和词法层面的pattern,第9个head只有词法层面的pattern。
那对于第1、3、4、5、11、12这些头就可以直接将其剪枝删除掉或者只保留一个吗?很遗憾,笔者目前也没办法给出一个结论,笔者更倾向于它们其实都有作用,只能让下游任务模型自己去甄别这些pattern。目前已经有论文证明在机器翻译等场景下剪枝掉某些head效果会更好:论文地址。
对于大多数没怎么发生变化的pattern,实在太奇怪了。笔者怀疑大部分head只是学习到了position embedding主导的相关信息,即对于任何的query,只要是这个位置,attention pattern是差不多的,果然当笔者把position embedding去掉后,记录了词法、语法层面的第7、9head的pattern几乎没怎么变,而那些之前大部分一成不变的pattern,除了第5、12head外都发生了剧烈的变化,有的甚至开始凸显出了词法信息和语法信息(比如4、6head):
这里可以得到一个结论,多头attention的有些头的功能是不一样的,有的头可能没啥信息(如第5head),有的头pattern由位置信息主导,有的头由语法信息主导,有的头由词法信息主导,而能够捕捉到语法/句法/词法信息的头其实是非常少的(这一点已被大量学术论文证明,笔者的句法破坏实验也验证了这一点),但是为了保证这些pattern能够被抽取出来,需要让其有一定的头的基数,因为单头很容易就变成自己注意力全在自己身上了,这一点也可以从'以上pattern中大部分pattern都是自己关注自己’这个现象身上得到佐证。假如将12头强行变成2头,注意力pattern就会变成如下这种自己注意自己的pattern,因为12个头大部分pattern都是自己关注自己,向量拼接起来后token间的向量相似度会被大部分分量所裹挟,因此也变成自己关注自己。
从1.4和1.5我们可以知道,bert本身是不期望self-attention收敛到某一种固定pattern上的,它期望可以捕捉到一些多样化的pattern,bert这类语言模型,token的注意力很容易只看到自己,出现只关注自己的pattern,如下图所示:
而我们更期望还能捕捉出这类多样的pattern出来:
注意!这里并不是说这种pattern就一定好,也不是说那种位置信息主导的pattern就一定不好,只是说我们期望捕捉更多的模式,从而利于下游多样的任务微调时,一旦这类pattern有用就可以激活出来让下游任务可以学习到,所以Transformer的角色定位是特征抽取器。所以多头对一个向量切分不同的维度来捕捉不同的pattern,这里就可以解释论文里原话中的不同维度的语义信息。
从另一个方面来解释,多头的核心思想就是ensemble,如随机森林一样,将特征切分,每个head就像是一个弱分类器,让最后得到的embedding关注多方面信息,不要过拟合到某一种pattern上,这一点上面的实验图像可以很清晰的看出来。
那这样来说,头是不是越多越好呢?答案是否定的,已有论文证明,头数不是越多越好,论文实验结果如下:
可以看到8/16个头时,PPL/BLEU最好,4/32个头次之,1个头最差。头确实不是越多越好,头太多了,每个qkv分到的维度就会降低,表达能力也就变差,也未必能更好的捕捉到语法/句法/词法信息。
为了进一步验证BERT多层、多头之间的变化,笔者对10w条长度相同的query绘制了1~4层的多头Attention矩阵热力图,将其多头Attention矩阵求平均后绘制出如下热力图,从上到下依次是1~4层,从左到右依次是1~12个head:
这种大量数据矩阵平均的情况,如果某一头的均值attention还能保持一个明显的pattern,那就说明对于任何一个query,该层该头的pattern是差不多的,因此该头的pattern是以位置信息为主导的,几乎不包含语义信息,因为不论啥query,在这个位置的pattern都一个样。反之如果该层该头热力图出现大片同色系区域,则可以一定程度上说明该层该头有可能记录了大量的语义、语法、词法信息,因为这些信息有可能出现在该query的任意位置,因此每个位置平均值大概率就是趋于相同。
可以看到,对于第1、3、4、5、11、12这些头基本pattern是很固定的,而且随着层数的增加,pattern越来越固定,而2、7、8、9、10的pattern在某几层或者全部层都比较多样化(接近于同色系),这也可以说明,为什么需要多层的Transformer堆叠,因为有些信息可能在某一层之中无法捕捉到,需要在其它层捕捉。
最后,基本所有head在最后一层都变得很固定(颜色很深,pattern很明显),随着层数的加深,头和头之间的差异越来越小。
已有论文对层和多头pattern之间的关系做了探讨:论文地址,这篇文章对12层的BERT绘制了层数和头pattern分布的关系,如下图所示,不同的颜色代表不同的层,同一颜色的分布代表了同一层的头差距。比如:
(1)我们可以先看看第一层,也就是黑色。在上下左右都有黑色的点出现,分布是比较分散稀疏的。
(2)再看看第六层浅紫色的点,相对来说分布比较集中了。
(3)再看看第十二层,深红色,基本全部集中在下方,分布非常集中。
由此可以得到这边论文的结论:头之间的差距随着所在层数变大而减少。换句话说,头之间的方差随着所在层数的增大而减小。笔者这里的观察实验也论证了这一点,随着层数的增加,pattern在慢慢固定且趋同。
在此,笔者可以得到一些对Multi-Head-Attention的结论:
(1)对于大部分query,每个头都学习了某种固定的pattern模式,而且12个头中大部分pattern是差不多的,但是总有少数的pattern才能捕捉到语法/句法/词法信息。
(2)越靠近底层的attention,其pattern种类越丰富,关注到的点越多,越到顶层的attention,各个head的pattern趋同。
(3)head数越少,pattern会更倾向于token关注自己本身(或者其他的比较单一的模式,比如都关注CLS)。
(4)多头的核心思想应该就是ensemble,如随机森林一样,将特征切分,每个head就像是一个弱分类器,让最后得到的embedding关注多方面信息,不要过拟合到某一种pattern上。
(5)已有论文证明head数目不是越多越好,bert-base上实验的结果为8、16最好,太多太少都会变差。
(6)multi-head-attention中大部分头没有捕捉到语法/句法信息,但是笔者这里没办法做出断言说它们是没有用的,具体还是要看下游任务对其的适配程度。个人倾向于大部分pattern只是不符合人类的语法,在不同的下游任务中应该还是有用武之地的。
———————————————————————————————
【AI科研看爱加】也欢迎各领域科研爱好者或AI爱好者们加入到爱加大家庭中来,共同运营与维护这个公益性AI公众号,让我们站在各研究领域、学习领域或者工作领域,共同讲好AI故事、传播AI声音,受益于更多对AI有兴趣的伙伴们.
联系客服