其实最近在看一些目标检测或目标追踪的文章外。我顺带学习了一些元学习(meta-learning),为少样本学习(few shot learning)领域的一个分支。
学习的目的是考虑到现有的多数深度学习模型需要大量的数据,而有些任务的数据量有限。所以后面会主攻以下两个方向,即
我希望通过写这种知乎文章,来加强我对模型和算法本身的理解,同时也为想学习这两个方向的小白提供一点帮助。
那么话不多说,我们开始对MAML算法进行解析。
MAML,全称呼叫做Model-Agnostic Meta-Learning ,意思就是模型无关的元学习。所以MAML可并不是一个深度学习模型,倒是更像一种训练技巧。
如果你对few-shot learning 或者meta learning的基础知识不懂,那么我并不推荐你去直接看论文,那会让你想放弃对这个领域的学习。
根据我的学习经验,我非常推荐你去看以下李宏毅老师的教学视频。链接如下
“https://www.bilibili.com/video/BV15b411g7Wd?p=57www.bilibili.com
我就是看了两遍视频后,并根据知乎上的文章,加深了对MAML的理解
“知乎徐安言:Model-Agnostic Meta-Learning (MAML) 链接:https://zhuanlan.zhihu.com/p/57864886
这里,我们根据的代码是基于Pytorch的,链接如下:
“https://github.com/dragen1860/MAML-Pytorchgithub.com
MAML 的中文名就是模型无关的元学习。意思就是不论什么深度学习模型,都可以使用MAML来进行少样本学习。论文中提到该方法可以用在分类、回归,甚至强化学习上。
本文我们的代码是基于分类的,那么我们就从分类的角度展开对MAML的解析。
2.1 Meta Learning的一些基础知识
Meta Learning(元学习),也可以称为“learning to learn”。常见的深度学习模型,比如对猫狗的分类模型,使用较多的是卷积神经网络模型,可以是VGG/ResNet等。
那么我们构建好了模型后,学习的就是模型的参数,学习的目的就是使得最终的参数能够在训练集上达到最佳的精度,损失最小。
但是元学习面向的是学习的过程,并不是学习的结果,也就是元学习不需要学出来最终的模型参数,学习的更像是学习技巧这种东西(这就是为什么叫做learning to learn)。
举个例子,人类在进行分类的时候,由于见过太多东西了,且已经学过太多东西的分类了。那么我们可能只需每个物体一张照片,就可以对物体做到很多的区分了,那么人是怎么根据少量的图片就能学习到如此好的成果呢?
本文介绍的MAML,其实是一种固定模型的meta learning ,可能会有人问
模型无关的意思是该方法可以用在CNN,也可以用在RNN,甚至可以用在RL中。但是MAML做的是固定模型的结构,只学习初始化模型参数这件事。
什么意思呢?就是我们希望通过meta-learning学习出一个非常好的模型初始化参数,有了这个初始化参数后,我们只需要少量的样本就可以快速在这个模型中进行收敛。
那么既然是learning to learn,那么输入就不再是单纯的数据了,而是一个个的任务(task)。就像人类在区分物体之前,已经看过了很多中不同物体的区分任务(task),可能是猫狗分类,苹果香蕉分类,男女分类等等,这些都是一个个的任务task。那么MAML的输入是一个个的task,并不是一条条的数据,这与常见的机器学习和深度学习模型是不同的。
2.2 N-way K-shot learning
用于分类任务的MAML,可以理解为一种N-way K-shot learning,这里的N是用于分类的类别数量。K为每个类别的数据量(用于训练)。
什么意思呢?我觉着这篇文章(链接:https://zhuanlan.zhihu.com/p/57864886)解释的就很到位。以下作为引用:
“MAML的论文中多次出现名词task,模型的训练过程都是围绕task展开的,而作者并没有给它下一个明确的定义。要正确地理解task,我们需要了解的相关概念包括,,support set,query set,meta-train classes,meta-test classes等等。是不是有点眼花缭乱?不要着急,举个简单的例子,大家就可以很轻松地掌握这些概念。
“我们假设这样一个场景:我们需要利用MAML训练一个数学模型模型,目的是对未知标签的图片做分类,类别包括(每类5个已标注样本用于训练。另外每类有15个已标注样本用于测试)。我们的训练数据除了中已标注的样本外,还包括另外10个类别的图片(每类30个已标注样本),用于帮助训练元学习模型。我们的实验设置为5-way 5-shot。
“关于具体的训练过程,会在下一节MAML算法详解中介绍。这里我们只需要有一个大概的了解:MAML首先利用的数据集训练元模型,再在的数据集上精调(fine-tune)得到最终的模型。
“此时,即meta-train classes,包含的共计300个样本,即,是用于训练的数据集。与之相对的,即meta-test classes,包含的共计100个样本,即,是用于训练和测试的数据集。
“根据5-way 5-shot的实验设置,我们在训练阶段,从中随机取5个类别,每个类别再随机取20个已标注样本,组成一个task 。其中的5个已标注样本称为的support set,另外15个样本称为的query set。这个task , 就相当于普通深度学习模型训练过程中的一条训练数据。那我们肯定要组成一个batch,才能做随机梯度下降SGD对不对?所以我们反复在训练数据分布中抽取若干个这样的task,组成一个batch。在训练阶段,task、support set、query set的含义与训练阶段均相同
作者的理解很到位,上面我们也说过MAML的数据是一个个的任务,而不是数据。
那么N-way K-shot就是一个个的任务。任务的类别为N,每个类别的Support set为K,至于query set大小需要人为进行选择(上例中选择了15,这是根据 中“每类有15个已标注样本用于测试”决定的)。
2.3 MAML算法流程
MAML中是存在两种梯度下降的,也就是gradient by gradient。第一种梯度下降是每个task都会执行的,而第二种梯度下降只有等batch size个task全部完成第一种梯度下降后才会执行的。
原文中是使用这样的伪代码进行MAML算法描述的。
感觉看起来不是很直观,不妨看我下面的解析。
以上面的5-way 5-shot例子为例,这里我们简单叙述下MAML的算法流程。
至此,MAML的算法流程基本就结束了。
可以看出,每个batch个task中进行batch次第一种梯度下降以及一次第二种梯度下降。
2.4 梯度近似计算
上述算法流程结束后,我们可以获得三个等式,借用李宏毅老师课堂ppt,三个等式如下。
那么我们需要求解
这三个等式的第一个就是最终我们需要求解的等式。那么这个等式中最重要的就是总损失对 的梯度计算,即
然后我们对上述公式里面的梯度进行拆分计算,即对的梯度计算,为
上面是把 拆成了一个个的标量,然后分别计算后再整合。
拆完 后,我们拆 ,如下图所示。
那么根据链式法则,可得
根据上述三个等式中的最后一个,也就是
我们将 和 进行对应,获得以下公式
接着分别公式4中的 i 和 j 进行分析。获得如下等式。
那么作者考虑到这个二次微分不好计算,就假设这个二次微分为0来进行近似计算。如下
那么将公式8的近似结果带入上面的公式4中。那么公式4就可以化简为
这种近似在论文中也体现了出来,如下
“we also include a comparison to dropping this backward pass and using a first-order approximation。
这个 a first-order approximation就是对二次微分的忽略。
那么根据公式9,公式3可以变化为
那么将公式10带入公式2中,就可以简化梯度计算了。
至此有关MAML的解析就结束了。
MAML是meta learning领域非常重要的一种算法。本文主要从原理的角度,结合了一些前人的经验,展开了对MAML的解析,我们发现本文最后的梯度计算中直接忽略二次微分,这样的这样的做法看似比较“鲁莽”,后面将结合下一个meta learning 算法,即Reptile,对MAML这个“鲁莽”行为带来的后果进行分析。
联系客服