原文是:《The Impact of Imbalanced Training Data for Convolutional Neural Networks》
本博客是该论文的阅读笔记,不免有很多细节不对之处。
还望各位看官能够见谅,欢迎批评指正。
更多相关博客请猛戳:http://blog.csdn.net/cyh_24
如需转载,请附上本文链接:http://blog.csdn.net/cyh_24/article/details/49871387
Abstract
本文主要研究使用不平衡数据训练CNN对图像分类的影响。文中使用的数据集是CIFAR-10,作者使用这个数据库,人工地对不同类别生成不同数量分布的数据。比如,让一个类别的图像占很大的比例,而另一类占很小的比例。使用这些生成的不同的训练集,均去训练一个CNN,并测试得到相应的准确率。
结果显示,不平衡训练集会对结果造成很大的负面影响,而训练集在平衡的情况下,能够达到最好的performance。
并且,文中得出一个结论:oversampling是一个很好的效的方式来解决不平衡训练集的问题。
实验过程
Dataset
使用的数据集是CIFAR-10,该数据集有10个类,每类6000张,共6w张图像。
对CIFAR-10进行数据切分,使用其中的5000张作为训练,1000作为测试图像。
生成不同数据分布
解释一下上图:
- Dist.1 是balanced data,每个类都占10%比重;
- Dist.2表明airplane,automobile,bird和cat各占8%,而其他类别各占12%…这个应该能看懂吧。
所以,现在有了11个训练集,接下来使用相同的CNN来训练,还是使用原来的test data进行测试。
Oversampling
文中使用的oversampling方式非常简单:
对于每一类,随机选出一些图片进行复制,直到该类图片数量与占最大比重的图片相等。
Results
Distribution Performace
Oversampling Performance
以上是经过oversampling之后的训练的CNN的performance,可以看出,几乎每个类都有提升,不过Dist.1(balanced training data)还是最高的。
Total Performance
平均以下每个Dist的准确率,得到如下表所示的准确率比较图,深色是imbalanced 的准确率,浅色是oversampling之后的准确率。
文章目标很明确,思路也很简单,并没有其他trick,我也就讲到这了。
总结一下,文章讲的事情和结论:
- 训练数据分布情况对CNN结果产生很大影响;
- 显然,balanced训练集是最优的,数据越不平衡,准确率越差;
- 使用Oversampling能够提升准确率;
版权声明:如需转载,请附上本文链接,不甚感激!作者主页:http://blog.csdn.net/cyh_24