什么是 Few-shot learning (小样本学习)?
最近总是看到few-shot learning, 记录一下
文章目录
定义
Few-shot Learning 是 Meta Learning 在监督学习领域的应用
也就是说一般看到few shot 就等于meta learning
-
人类非常擅长通过极少量的样本识别一个新物体,比如小孩子只需要书中的一些图片就可以认识什么是“斑马”,什么是“犀牛”。在人类的快速学习能力的启发下,研究人员希望机器学习模型在学习了一定类别的大量数据后,对于新的类别,只需要少量的样本就能快速学习,这就是 Few-shot Learning 要解决的问题。
-
实际上的操作来说, 对于分类任务, 我训练集判断同类别的相似度, 测试时拿出没见过的类别, 同时给定support set, 通过训练集得到的网络提取特征, 然后和support set中的类判断相似度, 相似的就定为一个类
-
support set 是指的小数据集
训练
- 关键点在于判断相似
训练过程
在这个例子中, 训练集可以训练得到一个判断相似度网络, 然后拿不在训练集的query去与support set中的数据做相似度判断, 其中 K-way 代表support set有k个类, n-shot 代表每个类有几个samples
下面这个图就是four way two shot的
模型分类
**Few-shot Learning模型大致可分为三类:**Mode Based,Metric Based 和 Optimization Based。
-
Model Based 方法旨在通过模型结构的设计快速在少量样本上更新参数,直接建立输入 x 和预测值 P 的映射函数
-
Metric Based 方法通过度量 batch 集中的样本和 support 集中样本的距离,借助最近邻的思想完成分类
-
Optimization Based 方法认为普通的梯度下降方法难以在 few-shot 场景下拟合,因此通过调整优化方法来完成小样本分类的任务。
主要介绍metric based的方法
该方法是对样本间距离分布进行建模,使得同类样本靠近,异类样本远离。主要是通过孪生网络(Siamese Network)通过有监督的方式训练孪生网络来学习,然后重用网络所提取的特征进行 one/few-shot 学习
Saiamese Network
一般训练过程
-
首先构建正负样本, 同类照片为1, 不同类为0
-
把数据通过同一个卷积网络, 然后把得到的特征判断相似度, 如果是接近1则代表距离越近, 接近0则代表距离越远, loss就是标签与预测之前的cross entropy
-
训练的时候就可以使用Query进行查询了
Tripet Loss
- 构建一个带有正负样本的三元组, 这样对一个样本的判断力就高了, 正样本距离小,负样本距离大
- 定义损失函数
有个margin, 假如负样本距离大于正样本加个margin, 就是分类正确的, 否则就把loss定义为差的形式
希望loss越小越好
- 测试,用query和support set
Pretraining and Fine Tuning
- 基本想法是在大规模数据集上预训练模型, 然后在小规模的support set上做fine-tuning
Cosine Similarity
- 对于单位向量, 就是直接算它们的内积
- 如果不是单位向量就用下面的方法:
其实就是一种求夹角余弦的方法
Softmax Classifier
- 假如类别数量等于K, 那么向量p就是K维的,
Fine Tuning
这是没有Fine—Tuning
- 训练一个提取特征的网络, 可以用孪生网络, 也可以用基础的有监督的卷积,然后把全连接层去掉
- 使用预训练的网络对Support set的图片提特征然后求平均归一化
- 假如这是个3way2shot的, 就如下图:
- 把query输入预训练网络得到特征然后归一化为d维度单位向量, 然后这个特征乘M, 也就是support set中的数据归一化后的向量(这里做的是内积,越相似越大)
下面有Fine-Tuning
- 刚才固定了W和b, W是support set的向量, b是0
- 现在初始化W和b为上面的, 但是会在support上学习W和b, 这就是fine tuning
那怎么学习呢?需要下面的loss
判断support中所有数据真实标签和预测概率的crossentropy, 加reg防止过拟合
这里的反向传播还可以加到预训练层更好的提特征
Entropy Regularization熵正则化
这个是利用了信息熵
- 信息熵越小, 越符合我们的认知
总结
- 预训练一个提取特征的模型
- 初始化W和b, 求解query的概率, 其中W的维度和support set的类的个数有关
- 对预测概览和标签做entropy loss
- 对W和b反向传播, 也可以传播到预训练阶段的特征提取网络
更多推荐
所有评论(0)