本文将从Transformer的本质、Transformer的原理、Transformer架构改进三个方面,带您一文搞懂Transformer。Transformer架构:主要由输入部分(输入输出嵌入与位置编码)、多层编码器、多层解码器以及输出部分(输出线性层与Softmax)四大部分组成。
Transformer架构
输入部分:
源文本嵌入层:将源文本中的词汇数字表示转换为向量表示,捕捉词汇间的关系。
位置编码器:为输入序列的每个位置生成位置向量,以便模型能够理解序列中的位置信息。
目标文本嵌入层(在解码器中使用):将目标文本中的词汇数字表示转换为向量表示。
编码器部分:
由N个编码器层堆叠而成。
每个编码器层由两个子层连接结构组成:第一个子层是一个多头自注意力子层,第二个子层是一个前馈全连接子层。每个子层后都接有一个规范化层和一个残差连接。
解码器部分:
由N个解码器层堆叠而成。
每个解码器层由三个子层连接结构组成:第一个子层是一个带掩码的多头自注意力子层,第二个子层是一个多头注意力子层(编码器到解码器),第三个子层是一个前馈全连接子层。每个子层后都接有一个规范化层和一个残差连接。
输出部分:
线性层:将解码器输出的向量转换为最终的输出维度。
Softmax层:将线性层的输出转换为概率分布,以便进行最终的预测。
Encoder-Decoder(编码器-解码器):左边是N个编码器,右边是N个解码器,Transformer中的N为6。
Encoder-Decoder(编码器-解码器)
Encoder(编码器)架构
Decoder(解码器)架构
Transformer工作原理
Multi-Head Attention(多头注意力):它允许模型同时关注来自不同位置的信息。通过分割原始的输入向量到多个头(head),每个头都能独立地学习不同的注意力权重,从而增强模型对输入序列中不同部分的关注能力。
Multi-Head Attention(多头注意力)
输入线性变换:对于输入的Query(查询)、Key(键)和Value(值)向量,首先通过线性变换将它们映射到不同的子空间。这些线性变换的参数是模型需要学习的。
分割多头:经过线性变换后,Query、Key和Value向量被分割成多个头。每个头都会独立地进行注意力计算。
缩放点积注意力:在每个头内部,使用缩放点积注意力来计算Query和Key之间的注意力分数。这个分数决定了在生成输出时,模型应该关注Value向量的部分。
注意力权重应用:将计算出的注意力权重应用于Value向量,得到加权的中间输出。这个过程可以理解为根据注意力权重对输入信息进行筛选和聚焦。
拼接和线性变换:将所有头的加权输出拼接在一起,然后通过一个线性变换得到最终的Multi-Head Attention输出。
Scaled Dot-Product Attention(缩放点积注意力):它是Transformer模型中多头注意力机制的一个关键组成部分。
Scaled Dot-Product Attention(缩放点积注意力)
Query、Key和Value矩阵:
Query矩阵(Q):表示当前的关注点或信息需求,用于与Key矩阵进行匹配。
Key矩阵(K):包含输入序列中各个位置的标识信息,用于被Query矩阵查询匹配。
Value矩阵(V):存储了与Key矩阵相对应的实际值或信息内容,当Query与某个Key匹配时,相应的Value将被用来计算输出。
点积计算:
缩放因子:
Softmax函数:
加权求和:
BERT:BERT是一种基于Transformer的预训练语言模型,它的最大创新之处在于引入了双向Transformer编码器,这使得模型可以同时考虑输入序列的前后上下文信息。
BERT架构
输入层(Embedding):
Token Embeddings:将单词或子词转换为固定维度的向量。
Segment Embeddings:用于区分句子对中的不同句子。
Position Embeddings:由于Transformer模型本身不具备处理序列顺序的能力,所以需要加入位置嵌入来提供序列中单词的位置信息。
编码层(Transformer Encoder):BERT模型使用双向Transformer编码器进行编码。
输出层(Pre-trained Task-specific Layers):
GPT:GPT也是一种基于Transformer的预训练语言模型,它的最大创新之处在于使用了单向Transformer编码器,这使得模型可以更好地捕捉输入序列的上下文信息。
GPT架构
输入层(Input Embedding):
编码层(Transformer Encoder):GPT模型使用单向Transformer编码器进行编码和生成。
输出层(Output Linear and Softmax):
本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请
点击举报。