tiny-cnn开源库的使用(MNIST)

tiny-cnn是一个基于CNN的开源库,它的License是BSD 3-Clause。作者也一直在维护更新,对进一步掌握CNN很有帮助,因此下面介绍下tiny-cnn在windows7 64bit vs2013的编译及使用。

1.      从https://github.com/nyanp/tiny-cnn下载源码:

$ git clone https://github.com/nyanp/tiny-cnn.git  版本号为77d80a8,更新日期2016.01.22

2.      源文件中已经包含了vs2013工程,vc/tiny-cnn.sln,默认是win32的,examples/main.cpp需要OpenCV的支持,这里新建一个x64的控制台工程tiny-cnn;

3.      仿照源工程,将相应.h文件加入到新控制台工程中,新加一个test_tiny-cnn.cpp文件;

4.      将examples/mnist中test.cpp和train.cpp文件中的代码复制到test_tiny-cnn.cpp文件中;

#include <iostream>
#include <string>
#include <vector>
#include <algorithm>
#include <tiny_cnn/tiny_cnn.h>
#include <opencv2/opencv.hpp>

using namespace tiny_cnn;
using namespace tiny_cnn::activation;

// rescale output to 0-100
template <typename Activation>
double rescale(double x)
{
	Activation a;
	return 100.0 * (x - a.scale().first) / (a.scale().second - a.scale().first);
}

void construct_net(network<mse, adagrad>& nn);
void train_lenet(std::string data_dir_path);
// convert tiny_cnn::image to cv::Mat and resize
cv::Mat image2mat(image<>& img);
void convert_image(const std::string& imagefilename, double minv, double maxv, int w, int h, vec_t& data);
void recognize(const std::string& dictionary, const std::string& filename, int target);

int main()
{
	//train
	std::string data_path = "D:/Download/MNIST";
	train_lenet(data_path);

	//test
	std::string model_path = "D:/Download/MNIST/LeNet-weights";
	std::string image_path = "D:/Download/MNIST/";
	int target[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };

	for (int i = 0; i < 10; i++) {
		char ch[15];
		sprintf(ch, "%d", i);
		std::string str;
		str = std::string(ch);
		str += ".png";
		str = image_path + str;

		recognize(model_path, str, target[i]);
	}

	std::cout << "ok!" << std::endl;
	return 0;
}

void train_lenet(std::string data_dir_path) {
	// specify loss-function and learning strategy
	network<mse, adagrad> nn;

	construct_net(nn);

	std::cout << "load models..." << std::endl;

	// load MNIST dataset
	std::vector<label_t> train_labels, test_labels;
	std::vector<vec_t> train_images, test_images;

	parse_mnist_labels(data_dir_path + "/train-labels.idx1-ubyte",
		&train_labels);
	parse_mnist_images(data_dir_path + "/train-images.idx3-ubyte",
		&train_images, -1.0, 1.0, 2, 2);
	parse_mnist_labels(data_dir_path + "/t10k-labels.idx1-ubyte",
		&test_labels);
	parse_mnist_images(data_dir_path + "/t10k-images.idx3-ubyte",
		&test_images, -1.0, 1.0, 2, 2);

	std::cout << "start training" << std::endl;

	progress_display disp(train_images.size());
	timer t;
	int minibatch_size = 10;
	int num_epochs = 30;

	nn.optimizer().alpha *= std::sqrt(minibatch_size);

	// create callback
	auto on_enumerate_epoch = [&](){
		std::cout << t.elapsed() << "s elapsed." << std::endl;
		tiny_cnn::result res = nn.test(test_images, test_labels);
		std::cout << res.num_success << "/" << res.num_total << std::endl;

		disp.restart(train_images.size());
		t.restart();
	};

	auto on_enumerate_minibatch = [&](){
		disp += minibatch_size;
	};

	// training
	nn.train(train_images, train_labels, minibatch_size, num_epochs,
		on_enumerate_minibatch, on_enumerate_epoch);

	std::cout << "end training." << std::endl;

	// test and show results
	nn.test(test_images, test_labels).print_detail(std::cout);

	// save networks
	std::ofstream ofs("D:/Download/MNIST/LeNet-weights");
	ofs << nn;
}

void construct_net(network<mse, adagrad>& nn) {
	// connection table [Y.Lecun, 1998 Table.1]
#define O true
#define X false
	static const bool tbl[] = {
		O, X, X, X, O, O, O, X, X, O, O, O, O, X, O, O,
		O, O, X, X, X, O, O, O, X, X, O, O, O, O, X, O,
		O, O, O, X, X, X, O, O, O, X, X, O, X, O, O, O,
		X, O, O, O, X, X, O, O, O, O, X, X, O, X, O, O,
		X, X, O, O, O, X, X, O, O, O, O, X, O, O, X, O,
		X, X, X, O, O, O, X, X, O, O, O, O, X, O, O, O
	};
#undef O
#undef X

	// construct nets
	nn << convolutional_layer<tan_h>(32, 32, 5, 1, 6)  // C1, [email protected], [email protected]
		<< average_pooling_layer<tan_h>(28, 28, 6, 2)   // S2, [email protected], [email protected]
		<< convolutional_layer<tan_h>(14, 14, 5, 6, 16,
		connection_table(tbl, 6, 16))              // C3, [email protected], [email protected]
		<< average_pooling_layer<tan_h>(10, 10, 16, 2)  // S4, [email protected], [email protected]
		<< convolutional_layer<tan_h>(5, 5, 5, 16, 120) // C5, [email protected], [email protected]
		<< fully_connected_layer<tan_h>(120, 10);       // F6, 120-in, 10-out
}

void recognize(const std::string& dictionary, const std::string& filename, int target) {
	network<mse, adagrad> nn;

	construct_net(nn);

	// load nets
	std::ifstream ifs(dictionary.c_str());
	ifs >> nn;

	// convert imagefile to vec_t
	vec_t data;
	convert_image(filename, -1.0, 1.0, 32, 32, data);

	// recognize
	auto res = nn.predict(data);
	std::vector<std::pair<double, int> > scores;

	// sort & print top-3
	for (int i = 0; i < 10; i++)
		scores.emplace_back(rescale<tan_h>(res[i]), i);

	std::sort(scores.begin(), scores.end(), std::greater<std::pair<double, int>>());

	for (int i = 0; i < 3; i++)
		std::cout << scores[i].second << "," << scores[i].first << std::endl;

	std::cout << "the actual digit is: " << scores[0].second << ", correct digit is: "<<target<<std::endl;

	// visualize outputs of each layer
	//for (size_t i = 0; i < nn.depth(); i++) {
	//	auto out_img = nn[i]->output_to_image();
	//	cv::imshow("layer:" + std::to_string(i), image2mat(out_img));
	//}
	//// visualize filter shape of first convolutional layer
	//auto weight = nn.at<convolutional_layer<tan_h>>(0).weight_to_image();
	//cv::imshow("weights:", image2mat(weight));

	//cv::waitKey(0);
}

// convert tiny_cnn::image to cv::Mat and resize
cv::Mat image2mat(image<>& img) {
	cv::Mat ori(img.height(), img.width(), CV_8U, &img.at(0, 0));
	cv::Mat resized;
	cv::resize(ori, resized, cv::Size(), 3, 3, cv::INTER_AREA);
	return resized;
}

void convert_image(const std::string& imagefilename,
	double minv,
	double maxv,
	int w,
	int h,
	vec_t& data) {
	auto img = cv::imread(imagefilename, cv::IMREAD_GRAYSCALE);
	if (img.data == nullptr) return; // cannot open, or it‘s not an image

	cv::Mat_<uint8_t> resized;
	cv::resize(img, resized, cv::Size(w, h));

	// mnist dataset is "white on black", so negate required
	std::transform(resized.begin(), resized.end(), std::back_inserter(data),
		[=](uint8_t c) { return (255 - c) * (maxv - minv) / 255.0 + minv; });
}

5.      编译时会提示几个错误,解决方法是:

(1)、error C4996,解决方法:将宏_SCL_SECURE_NO_WARNINGS添加到属性的预处理器定义中;

(2)、调用for_函数时,error C2668,对重载函数的调用不明教,解决方法:将for_中的第三个参数强制转化为size_t类型;

6.      运行程序,train时,运行结果如下图所示:

7.      对生成的model进行测试,通过画图工具,每个数字生成一张图像,共10幅,如下图:

通过导入train时生成的model,对这10张图像进行识别,识别结果如下图,其中6和9被误识为5和1:

时间: 2024-12-20 21:10:42

tiny-cnn开源库的使用(MNIST)的相关文章

C++开源库大全(转)

程序员要站在巨人的肩膀上,C++拥有丰富的开源库,这里包括:标准库.Web应用框架.人工智能.数据库.图片处理.机器学习.日志.代码分析等. 标准库 C++ Standard Library:是一系列类和函数的集合,使用核心语言编写,也是C++ISO自身标准的一部分. Standard Template Library:标准模板库 C POSIX library : POSIX系统的C标准库规范 ISO C++ Standards Committee :C++标准委员会 框架 C++通用框架和库

站在巨人的肩膀上,C++开源库大全

程序员要站在巨人的肩膀上,C++拥有丰富的开源库,这里包括:标准库.Web应用框架.人工智能.数据库.图片处理.机器学习.日志.代码分析等. 标准库 C++ Standard Library:是一系列类和函数的集合,使用核心语言编写,也是C++ISO自身标准的一部分. Standard Template Library:标准模板库 C POSIX library : POSIX系统的C标准库规范 ISO C++ Standards Committee :C++标准委员会 框架 C++通用框架和库

开源库BaseRecyclerViewAdapterHelper

相信大家RecyclerView应该不会陌生,大多数开发者应该都使用上它了,它也是google推荐替换ListView的控件,但是用过它的同学应该都知道它在某些方面并没有ListView使用起来方便,需要我们额外的编写代码,今天就给大家介绍一个开源库BaseRecyclerViewAdapterHelper,有了它让你使用RecyclerView的时候,和ListView一样的好用! 那么你要问了,BaseRecyclerViewAdapterHelper能做什么? 优化Adapter代码(减少

【计算机视觉】OpenCV的最近邻开源库FLANN

FLANN介绍 FLANN库全称是Fast Library for Approximate Nearest Neighbors,它是目前最完整的(近似)最近邻开源库.不但实现了一系列查找算法,还包含了一种自动选取最快算法的机制. flann::Index_类 该类模板是最近邻索引类,该类用于抽象不同类型的最近邻搜索的索引. 以下是flann::Index_类的声明: template <typename T> class #ifndef _MSC_VER FLANN_DEPRECATED #e

GitHub Top 100的Android开源库

本项目主要对目前 GitHub 上排名前 100 的 Android 开源库进行简单的介绍, 至于排名完全是根据GitHub搜索Java语言选择「Best Match」得到的结果,然后过滤了跟Android不相关的项目,所以排名并不具备任何官方效力,仅供参考学习,方便初学者快速了解当前一些流行的Android开源库. 1. React Native 这个是 Facebook 在 React.js Conf 2015 大会上推出的基于 JavaScript 的开源框架 React Native,

【开源框架】Android之史上最全最简单最有用的第三方开源库收集整理,有助于快速开发,欢迎各位...

[转]http://www.tuicool.com/articles/jyA3MrU Android开源库 自己一直很喜欢Android开发,就如博客签名一样, 我是程序猿,我为自己代言 . 在摸索过程中,GitHub上搜集了很多很棒的Android第三方库,推荐给在苦苦寻找的开发者,而且我会 不定期的更新 这篇文章. Android下的优秀开发库数不胜数,在本文中,我列举的多是开发流程中最常用的一些.如果你还想了解更多的Android开源库,可以关注我的博客,每一个库都是我认真查看或者编译运行

在移动开发中常用的开源库总结

1.为什么需要开源库? 我个人觉得有以下几个原因: 1>我们的项目比较赶,但是又用到一些比较复杂的模块,这些模块不是系统自带的,或者说系统自带的满足不了需求,同时在一些开源网站上面又有类似的或者是满足我哦们需求的开源项目和库,拿来就可以减少我们很多的工作量. 2>开源库从另外一方面来说就是为了提高代码的重用性,大家使用了这个开源库,然后提交一些bug,通过大家的力量完善这个开源项目. 2.我常用的开源库? 我使用的一些开源项目主要都是在github上面很热门的项目: 图片加载:Android-

C++开源库,欢迎补充

C++在“商业应用”方面,曾经是天下第一的开发语言,但这一桂冠已经被java抢走多年.因为当今商业应用程序类型,已经从桌面应用迅速转移成Web应 用.当Java横行天下之后,MS又突然发力,搞出C#语言,有大片的曾经的C++程序员,以为C++要就此沉沦,未料,这三年来,C++的生命力突然被 严重地增强了.主力原因就是开源的软件.基础软件(比如并发原生支持,比如Android必定要推出原生的SDK).各种跨平台应用的出现. 开源C++库必须具有以下特点:必须是成熟的产品.跨平台的产品.相对通用的库

android开源库发布到jcenter图文详解与填坑

相信很多人都用过开源项目,特别是android studio普及以后,使用开源库更方便简单.而如何上传开源库到jcenter供大家方便使用,虽然网上也有教程,但还是遇坑了,最后总结一下,希望可以帮助大家. [csdn地址: http://blog.csdn.net/zhangke3016/article/details/52075159] [本文简书地址: http://www.jianshu.com/p/0acf9e05b27e]同步更新 AndroidStudio是从Maven Reposi