定义

Few-shot Learning 是 Meta Learning 在监督学习领域的应用

也就是说一般看到few shot 就等于meta learning

  • 人类非常擅长通过极少量的样本识别一个新物体,比如小孩子只需要书中的一些图片就可以认识什么是“斑马”,什么是“犀牛”。在人类的快速学习能力的启发下,研究人员希望机器学习模型在学习了一定类别的大量数据后,对于新的类别,只需要少量的样本就能快速学习,这就是 Few-shot Learning 要解决的问题。

  • 实际上的操作来说, 对于分类任务, 我训练集判断同类别的相似度, 测试时拿出没见过的类别, 同时给定support set, 通过训练集得到的网络提取特征, 然后和support set中的类判断相似度, 相似的就定为一个类

  • support set 是指的小数据集

在这里插入图片描述

训练

  • 关键点在于判断相似
训练过程

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-xGjuSRK7-1640002637806)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220171122302.png)]

在这里插入图片描述

在这个例子中, 训练集可以训练得到一个判断相似度网络, 然后拿不在训练集的query去与support set中的数据做相似度判断, 其中 K-way 代表support set有k个类, n-shot 代表每个类有几个samples

下面这个图就是four way two shot的

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kYojptyF-1640002637810)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220171422818.png)]

模型分类

**Few-shot Learning模型大致可分为三类:**Mode Based,Metric Based 和 Optimization Based。

  • Model Based 方法旨在通过模型结构的设计快速在少量样本上更新参数,直接建立输入 x 和预测值 P 的映射函数

  • Metric Based 方法通过度量 batch 集中的样本和 support 集中样本的距离,借助最近邻的思想完成分类

  • Optimization Based 方法认为普通的梯度下降方法难以在 few-shot 场景下拟合,因此通过调整优化方法来完成小样本分类的任务。

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-k47H7fwQ-1640002637811)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/format,png.png)]

主要介绍metric based的方法

该方法是对样本间距离分布进行建模,使得同类样本靠近,异类样本远离。主要是通过孪生网络Siamese Network)通过有监督的方式训练孪生网络来学习,然后重用网络所提取的特征进行 one/few-shot 学习

Saiamese Network
一般训练过程
  1. 首先构建正负样本, 同类照片为1, 不同类为0

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-biIsc14V-1640002637812)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220185010522.png)]

  2. 把数据通过同一个卷积网络, 然后把得到的特征判断相似度, 如果是接近1则代表距离越, 接近0则代表距离越loss就是标签与预测之前的cross entropy

    [外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hTSDCrbP-1640002637813)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220185844664.png)]

  3. 训练的时候就可以使用Query进行查询了

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-UZZtUpic-1640002637814)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220190016368.png)]

Tripet Loss
  1. 构建一个带有正负样本的三元组, 这样对一个样本的判断力就高了, 正样本距离小,负样本距离大

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Jq1xwuIQ-1640002637814)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220190209995.png)]

  1. 定义损失函数

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vIWKIJLl-1640002637815)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220193241736.png)]

有个margin, 假如负样本距离大于正样本加个margin, 就是分类正确的, 否则就把loss定义为差的形式

希望loss越小越好

  1. 测试,用query和support set

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-zfgLUO7h-1640002637816)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220193445857.png)]

Pretraining and Fine Tuning

  • 基本想法是在大规模数据集上预训练模型, 然后在小规模的support set上做fine-tuning
Cosine Similarity
  • 对于单位向量, 就是直接算它们的内积

在这里插入图片描述

在这里插入图片描述

  • 如果不是单位向量就用下面的方法:
    在这里插入图片描述

其实就是一种求夹角余弦的方法

Softmax Classifier

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-v8Xg8vvv-1640002637818)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220163155001.png)]

  • 假如类别数量等于K, 那么向量p就是K维的,
Fine Tuning
这是没有Fine—Tuning
  1. 训练一个提取特征的网络, 可以用孪生网络, 也可以用基础的有监督的卷积,然后把全连接层去掉

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-TbKmSqdY-1640002637819)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220195708126.png)]

  1. 使用预训练的网络对Support set的图片提特征然后求平均归一化
  • 假如这是个3way2shot的, 就如下图:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NoZ4nwgf-1640002637819)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220195846618.png)]

  1. 把query输入预训练网络得到特征然后归一化为d维度单位向量, 然后这个特征乘M, 也就是support set中的数据归一化后的向量(这里做的是内积,越相似越大)

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-kWtljZ7O-1640002637820)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220200140269.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Ixb98mr1-1640002637821)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220200210291.png)]

下面有Fine-Tuning
  • 刚才固定了W和b, W是support set的向量, b是0
  • 现在初始化W和b为上面的, 但是会在support上学习W和b, 这就是fine tuning

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-M2HpM7b6-1640002637821)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220200339597.png)]

那怎么学习呢?需要下面的loss

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-2ejbtrC5-1640002637822)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220200610553.png)]

判断support中所有数据真实标签和预测概率的crossentropy, 加reg防止过拟合

这里的反向传播还可以加到预训练层更好的提特征

Entropy Regularization熵正则化

这个是利用了信息熵

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OrA9w8MY-1640002637823)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220200922507.png)]

  • 信息熵越小, 越符合我们的认知

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-mDWVpoes-1640002637823)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220201144418.png)]

总结

  1. 预训练一个提取特征的模型
  2. 初始化W和b, 求解query的概率, 其中W的维度和support set的类的个数有关

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-urPfTKUx-1640002637824)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220201500238.png)]

  1. 对预测概览和标签做entropy loss

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qEn2X1bX-1640002637825)(F:/%E7%A0%94%E7%A9%B6%E7%94%9F%E6%95%B4%E7%90%86%E4%BF%A1%E6%81%AF/typora_pic/image-20211220201558522.png)]

  1. 对W和b反向传播, 也可以传播到预训练阶段的特征提取网络
Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐