打开APP
userphoto
未登录

开通VIP,畅享免费电子书等14项超值服

开通VIP
让机器学习“如何学习”!从零开始读懂MAML!

1. 前言

其实最近在看一些目标检测目标追踪的文章外。我顺带学习了一些元学习(meta-learning),为少样本学习(few shot learning)领域的一个分支。

学习的目的是考虑到现有的多数深度学习模型需要大量的数据,而有些任务的数据量有限。所以后面会主攻以下两个方向,即

  • 目标追踪,单目标的已经解析的差不多了,后面会着重解析多目标追踪领域的一些模型
  • 元学习,也可以理解为少样本学习,尝试在一些模型中应用元学习的技巧

我希望通过写这种知乎文章,来加强我对模型和算法本身的理解,同时也为想学习这两个方向的小白提供一点帮助。

那么话不多说,我们开始对MAML算法进行解析。

MAML,全称呼叫做Model-Agnostic Meta-Learning ,意思就是模型无关的元学习。所以MAML可并不是一个深度学习模型,倒是更像一种训练技巧

如果你对few-shot learning 或者meta learning的基础知识不懂,那么我并不推荐你去直接看论文,那会让你想放弃对这个领域的学习。

根据我的学习经验,我非常推荐你去看以下李宏毅老师的教学视频。链接如下

https://www.bilibili.com/video/BV15b411g7Wd?p=57www.bilibili.com

我就是看了两遍视频后,并根据知乎上的文章,加深了对MAML的理解

知乎徐安言:Model-Agnostic Meta-Learning (MAML) 链接:https://zhuanlan.zhihu.com/p/57864886

这里,我们根据的代码是基于Pytorch的,链接如下:

https://github.com/dragen1860/MAML-Pytorchgithub.com

2. 简单谈谈MAML

MAML 的中文名就是模型无关的元学习。意思就是不论什么深度学习模型,都可以使用MAML来进行少样本学习。论文中提到该方法可以用在分类回归,甚至强化学习上。

本文我们的代码是基于分类的,那么我们就从分类的角度展开对MAML的解析。

2.1 Meta Learning的一些基础知识

Meta Learning(元学习),也可以称为“learning to learn”。常见的深度学习模型,比如对猫狗的分类模型,使用较多的是卷积神经网络模型,可以是VGG/ResNet等。

那么我们构建好了模型后,学习的就是模型的参数,学习的目的就是使得最终的参数能够在训练集上达到最佳的精度损失最小

但是元学习面向的是学习的过程,并不是学习的结果,也就是元学习不需要学出来最终的模型参数,学习的更像是学习技巧这种东西(这就是为什么叫做learning to learn)。

举个例子,人类在进行分类的时候,由于见过太多东西了,且已经学过太多东西的分类了。那么我们可能只需每个物体一张照片,就可以对物体做到很多的区分了,那么人是怎么根据少量的图片就能学习到如此好的成果呢?

  • 显然 ,我们已经掌握了各种用于图片分类的较巧了,比如根据物体的轮廓纹理等信息进行分类,那么根据轮廓根据纹理等区分物体的方法,就是我们在meta learning中需要教机器进行学习的学习技巧

本文介绍的MAML,其实是一种固定模型的meta learning ,可能会有人问

  • 不是说MAML是模型无关的吗?为什么需要固定模型?

模型无关的意思是该方法可以用在CNN,也可以用在RNN,甚至可以用在RL中。但是MAML做的是固定模型的结构,只学习初始化模型参数这件事。

什么意思呢?就是我们希望通过meta-learning学习出一个非常好模型初始化参数,有了这个初始化参数后,我们只需要少量的样本就可以快速在这个模型中进行收敛。

那么既然是learning to learn,那么输入就不再是单纯的数据了,而是一个个的任务(task)。就像人类在区分物体之前,已经看过了很多中不同物体的区分任务(task),可能是猫狗分类,苹果香蕉分类,男女分类等等,这些都是一个个的任务task。那么MAML的输入是一个个的task,并不是一条条的数据,这与常见的机器学习和深度学习模型是不同的。

2.2  N-way  K-shot learning

用于分类任务的MAML,可以理解为一种N-way  K-shot learning,这里的N是用于分类的类别数量。K为每个类别的数据量(用于训练)。

什么意思呢?我觉着这篇文章(链接:https://zhuanlan.zhihu.com/p/57864886)解释的就很到位。以下作为引用:

MAML的论文中多次出现名词task,模型的训练过程都是围绕task展开的,而作者并没有给它下一个明确的定义。要正确地理解task,我们需要了解的相关概念包括support set,query set,meta-train classes,meta-test classes等等。是不是有点眼花缭乱?不要着急,举个简单的例子,大家就可以很轻松地掌握这些概念。

我们假设这样一个场景:我们需要利用MAML训练一个数学模型模型,目的是对未知标签的图片做分类,类别包括(每类5个已标注样本用于训练。另外每类有15个已标注样本用于测试)。我们的训练数据除了中已标注的样本外,还包括另外10个类别的图片(每类30个已标注样本),用于帮助训练元学习模型。我们的实验设置为5-way 5-shot

关于具体的训练过程,会在下一节MAML算法详解中介绍。这里我们只需要有一个大概的了解:MAML首先利用的数据集训练元模型,再在的数据集上精调(fine-tune)得到最终的模型

此时,meta-train classes包含的共计300个样本,即,是用于训练的数据集。与之相对的,meta-test classes包含的共计100个样本,即,是用于训练和测试的数据集。

根据5-way 5-shot的实验设置,我们在训练阶段,从中随机取5个类别,每个类别再随机取20个已标注样本,组成一个task 。其中的5个已标注样本称为support set,另外15个样本称为query set。这个task , 就相当于普通深度学习模型训练过程中的一条训练数据。那我们肯定要组成一个batch,才能做随机梯度下降SGD对不对?所以我们反复在训练数据分布中抽取若干个这样的task,组成一个batch。在训练阶段,tasksupport setquery set的含义与训练阶段均相同

作者的理解很到位,上面我们也说过MAML的数据是一个个的任务,而不是数据

那么N-way K-shot就是一个个的任务。任务的类别为N,每个类别的Support set为K,至于query set大小需要人为进行选择(上例中选择了15,这是根据 中“每类有15个已标注样本用于测试”决定的)。

2.3 MAML算法流程

MAML中是存在两种梯度下降的,也就是gradient  by gradient第一种梯度下降是每个task都会执行的,而第二种梯度下降只有等batch size个task全部完成第一种梯度下降后才会执行的。

原文中是使用这样的伪代码进行MAML算法描述的。

img

感觉看起来不是很直观,不妨看我下面的解析。

以上面的5-way 5-shot例子为例,这里我们简单叙述下MAML的算法流程。


    1. 上面我们已经将数据区分成了${\mathcal D}{meta-train} {\mathcal D}_{meta-test}{\mathcal D}{meta-train} {\mathcal D}_{meta-test} $中我们又将数据区分了support set,query set

    1. 我们用于训练的模型架构是 (假设初始化参数为 ),这可能是一个输出节点为5的CNN,训练的目的是为了使得模型有较优秀初始化参数。最终我们想要学出可以用于数据集${\mathcal D}{meta-test} M{fine-tune} M_{fine-tune}M_{meta}$ 的结构是一模一样的,不同的是模型参数

    1. 我们将1个任务tasksupport set去训练 ,这里进行第一种梯度下降,假设每个任务只进行一次梯度下降,也就是。那么执行第2个task训练时,有。执行第batch size个task后,有,如下图所示。
img

    1. 上述步骤3用了batch size个task对 进行了训练,然后我们使用上述batch个task中地query set去测试参数为模型效果,获得总损失函数,这个损失函数就是一个**batch task**中**每个task**的**query set**在各自参数为 中的损失 之和。

    1. 获得总损失函数后,我们就要对其进行第二种的梯度下降。即更新初始化参数 ,也就是 来更新初始化参数。这样不断地从步骤3开始训练,最终能够在数据集上获得该模型比较好的初始化参数。

    1. 根据这个初始化的参数以及该模型,我们用数据集 的support set对模型进行微调,这时候的梯度下降步数可以设置更多一点,不像训练时候(在第一次梯度下降过程中)只进行一步梯度下降。

    1. 最后微调结束后,使用 的query set进行模型的评估。

至此,MAML的算法流程基本就结束了。

可以看出,每个batch个task中进行batch次第一种梯度下降以及一次第二种梯度下降。

2.4 梯度近似计算

上述算法流程结束后,我们可以获得三个等式,借用李宏毅老师课堂ppt,三个等式如下。

img

那么我们需要求解

公式1

这三个等式的第一个就是最终我们需要求解的等式。那么这个等式中最重要的就是总损失对 的梯度计算,即

公式2

然后我们对上述公式里面的梯度进行拆分计算,即的梯度计算,为

公式3

上面是把 拆成了一个个的标量,然后分别计算后再整合

拆完 后,我们拆 ,如下图所示。

img

那么根据链式法则,可得

公式4

根据上述三个等式中的最后一个,也就是

公式5

我们将 进行对应,获得以下公式

公式6

接着分别公式4中的 i 和 j 进行分析。获得如下等式。

公式7

那么作者考虑到这个二次微分不好计算,就假设这个二次微分为0来进行近似计算。如下

公式8

那么将公式8的近似结果带入上面的公式4中。那么公式4就可以化简为

公式9

这种近似在论文中也体现了出来,如下

we also include a comparison to dropping this backward pass and using a first-order approximation。

这个 a first-order approximation就是对二次微分的忽略

那么根据公式9,公式3可以变化为

那么将公式10带入公式2中,就可以简化梯度计算了。

至此有关MAML的解析就结束了。

3.总结

MAML是meta learning领域非常重要的一种算法。本文主要从原理的角度,结合了一些前人的经验,展开了对MAML的解析,我们发现本文最后的梯度计算中直接忽略二次微分,这样的这样的做法看似比较“鲁莽”,后面将结合下一个meta learning 算法,即Reptile,对MAML这个“鲁莽”行为带来的后果进行分析。

本站仅提供存储服务,所有内容均由用户发布,如发现有害或侵权内容,请点击举报
打开APP,阅读全文并永久保存 查看更多类似文章
猜你喜欢
类似文章
【热】打开小程序,算一算2024你的财运
MAML-Tracker:用目标检测思路做目标跟踪?小样本即可得高准确率丨CVPR 2020
Meta Learning 4: 基于优化的方法
Per-FedAvg:联邦个性化元学习
手把手 | OpenAI开发可拓展元学习算法Reptile,能快速学习(附代码)
元图:通过元学习进行小样本的链接预测
Machine Learning Notes | 将复杂的机器学习算法说简单
更多类似文章 >>
生活服务
热点新闻
分享 收藏 导长图 关注 下载文章
绑定账号成功
后续可登录账号畅享VIP特权!
如果VIP功能使用有故障,
可点击这里联系客服!

联系客服