连接图像和文本,更多的多模态文章可以看博主整理过的系列(跨界出圈 | 谈谈BERT跨模态预训练),本篇文章主要整理一下OpenAI发表的2篇文章。其中CLIP 能够完成图像与文本类别的匹配,DALL·E 则可以直接基于文本描述生成图像,且性能十分优异。
论文:Learning Transferable Visual Models From Natural Language Supervision 地址:https://arxiv.org/pdf/2103.00020.pdf 代码:https://github.com/openai/CLIP
首先是CLIP,直接看模型吧,分为三步:Contrastive Pretraning,Create dataset classifier from label text和use for zero-shot prediction。
这一步是主要是利用大量的训练数据(直接从网上得到的句子-图像对)得到特征的表示。接下来的两步是测试过程,流程如下图:
同时,基于 CLIP 还可以自由定义自己的分类器!也就是说可以很方便的利用CLIP和很多工作结合,比如等会要整理的 DALL-E 中就用到了 CLIP来提特征。
简单看看CLIP里面的逻辑流程
def forward(self, image, text):
image_features = self.encode_image(image) #编码image
text_features = self.encode_text(text) #编码text
# norm一下特征
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# 计算内积相似度logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logit_scale * text_features @ image_features.t()
# shape = [global_batch_size, global_batch_size]
return logits_per_image, logits_per_text
论文:Zero-Shot Text-to-Image Generation 地址:https://arxiv.org/pdf/2102.12092.pdf 代码:https://github.com/openai/DALL-E
然后是DALL-E模型,CLIP主要可以做分类检索等任务,而它则可以直接根据文本生成效果非常好的图像。motivation是目标是训练一个transformer进行自动建模,即将文本以及图片的tokens转换为单一的数据流,所以主要是需要考虑如何对2D的图片也转为单数据流。
也是直接看模型,如上图可以分为三个阶段:dVAE,Transformer和CLIP。
值得注意的一些trick:
还有大佬复现code:https://github.com/lucidrains/DALLE-pytorch
这个复现的库可直接调用训练,似乎非常好用,如果你有足够的卡那么pip一下即可:
pip install dalle-pytorch
import torch
from dalle_pytorch import CLIP
clip = CLIP(
dim_text = 512,
dim_image = 512,
dim_latent = 512,
num_text_tokens = 10000,
text_enc_depth = 6,
text_seq_len = 256,
text_heads = 8,
num_visual_tokens = 512,
visual_enc_depth = 6,
visual_image_size = 256,
visual_patch_size = 32,
visual_heads = 8
) #设置CLIP的参数
text = torch.randint(0, 10000, (4, 256))
images = torch.randn(4, 3, 256, 256)
mask = torch.ones_like(text).bool()
loss = clip(text, images, text_mask = mask, return_loss = True) #直接训练CLIP
loss.backward()
对比学习: https://nakaizura.blog.csdn.net/article/details/108941999
[2]Vision Transformer: https://nakaizura.blog.csdn.net/article/details/113095927
- END -
联系客服