mxnet实战系列(一)入门与跑mnist数据集

最近在摸mxnet和tensorflow。两个我都搭起来了。tensorflow跑了不少代码,总的来说用得比较顺畅,文档很丰富,api熟悉熟悉写代码没什么问题。

今天把两个平台做了一下对比。同是跑mnist,tensorflow 要比mxnet 慢一二十倍。mxnet只需要半分钟,tensorflow跑了13分钟。

在mxnet中如何开跑?

cd /mxnet/example/image-classification
python train_mnist.py

我用的是最新的mxnet版本。运行脚本它会自动下载数据集。然后刷刷刷的刷屏了。我们来看看这个脚本如何写的,从而建立mxnet编程思路:import find_mxnetimport mxnet as mximport argparseimport os, sysimport train_model

def _download(data_dir):    if not os.path.isdir(data_dir):        os.system("mkdir " + data_dir)    os.chdir(data_dir)    if (not os.path.exists(‘train-images-idx3-ubyte‘)) or \       (not os.path.exists(‘train-labels-idx1-ubyte‘)) or \       (not os.path.exists(‘t10k-images-idx3-ubyte‘)) or \       (not os.path.exists(‘t10k-labels-idx1-ubyte‘)):        os.system("wget http://data.dmlc.ml/mxnet/data/mnist.zip")        os.system("unzip -u mnist.zip; rm mnist.zip")    os.chdir("..")

def get_loc(data, attr={‘lr_mult‘:‘0.01‘}):    """    the localisation network in lenet-stn, it will increase acc about more than 1%,    when num-epoch >=15    """    loc = mx.symbol.Convolution(data=data, num_filter=30, kernel=(5, 5), stride=(2,2))    loc = mx.symbol.Activation(data = loc, act_type=‘relu‘)    loc = mx.symbol.Pooling(data=loc, kernel=(2, 2), stride=(2, 2), pool_type=‘max‘)    loc = mx.symbol.Convolution(data=loc, num_filter=60, kernel=(3, 3), stride=(1,1), pad=(1, 1))    loc = mx.symbol.Activation(data = loc, act_type=‘relu‘)    loc = mx.symbol.Pooling(data=loc, global_pool=True, kernel=(2, 2), pool_type=‘avg‘)    loc = mx.symbol.Flatten(data=loc)    loc = mx.symbol.FullyConnected(data=loc, num_hidden=6, name="stn_loc", attr=attr)    return loc

def get_mlp():    """    multi-layer perceptron    """    data = mx.symbol.Variable(‘data‘)    fc1  = mx.symbol.FullyConnected(data = data, name=‘fc1‘, num_hidden=128)    act1 = mx.symbol.Activation(data = fc1, name=‘relu1‘, act_type="relu")    fc2  = mx.symbol.FullyConnected(data = act1, name = ‘fc2‘, num_hidden = 64)    act2 = mx.symbol.Activation(data = fc2, name=‘relu2‘, act_type="relu")    fc3  = mx.symbol.FullyConnected(data = act2, name=‘fc3‘, num_hidden=10)    mlp  = mx.symbol.SoftmaxOutput(data = fc3, name = ‘softmax‘)    return mlp

def get_lenet(add_stn=False):    """    LeCun, Yann, Leon Bottou, Yoshua Bengio, and Patrick    Haffner. "Gradient-based learning applied to document recognition."    Proceedings of the IEEE (1998)    """    data = mx.symbol.Variable(‘data‘)    if(add_stn):        data = mx.sym.SpatialTransformer(data=data, loc=get_loc(data), target_shape = (28,28),                                         transform_type="affine", sampler_type="bilinear")    # first conv    conv1 = mx.symbol.Convolution(data=data, kernel=(5,5), num_filter=20)    tanh1 = mx.symbol.Activation(data=conv1, act_type="tanh")    pool1 = mx.symbol.Pooling(data=tanh1, pool_type="max",                              kernel=(2,2), stride=(2,2))    # second conv    conv2 = mx.symbol.Convolution(data=pool1, kernel=(5,5), num_filter=50)    tanh2 = mx.symbol.Activation(data=conv2, act_type="tanh")    pool2 = mx.symbol.Pooling(data=tanh2, pool_type="max",                              kernel=(2,2), stride=(2,2))    # first fullc    flatten = mx.symbol.Flatten(data=pool2)    fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=500)    tanh3 = mx.symbol.Activation(data=fc1, act_type="tanh")    # second fullc    fc2 = mx.symbol.FullyConnected(data=tanh3, num_hidden=10)    # loss    lenet = mx.symbol.SoftmaxOutput(data=fc2, name=‘softmax‘)    return lenet

def get_iterator(data_shape):    def get_iterator_impl(args, kv):        data_dir = args.data_dir        if ‘://‘ not in args.data_dir:            _download(args.data_dir)        flat = False if len(data_shape) == 3 else True

        train           = mx.io.MNISTIter(            image       = data_dir + "train-images-idx3-ubyte",            label       = data_dir + "train-labels-idx1-ubyte",            input_shape = data_shape,            batch_size  = args.batch_size,            shuffle     = True,            flat        = flat,            num_parts   = kv.num_workers,            part_index  = kv.rank)

        val = mx.io.MNISTIter(            image       = data_dir + "t10k-images-idx3-ubyte",            label       = data_dir + "t10k-labels-idx1-ubyte",            input_shape = data_shape,            batch_size  = args.batch_size,            flat        = flat,            num_parts   = kv.num_workers,            part_index  = kv.rank)

        return (train, val)    return get_iterator_impl

def parse_args():    parser = argparse.ArgumentParser(description=‘train an image classifer on mnist‘)    parser.add_argument(‘--network‘, type=str, default=‘mlp‘,                        choices = [‘mlp‘, ‘lenet‘, ‘lenet-stn‘],                        help = ‘the cnn to use‘)    parser.add_argument(‘--data-dir‘, type=str, default=‘mnist/‘,                        help=‘the input data directory‘)    parser.add_argument(‘--gpus‘, type=str,                        help=‘the gpus will be used, e.g "0,1,2,3"‘)    parser.add_argument(‘--num-examples‘, type=int, default=60000,                        help=‘the number of training examples‘)    parser.add_argument(‘--batch-size‘, type=int, default=128,                        help=‘the batch size‘)    parser.add_argument(‘--lr‘, type=float, default=.1,                        help=‘the initial learning rate‘)    parser.add_argument(‘--model-prefix‘, type=str,                        help=‘the prefix of the model to load/save‘)    parser.add_argument(‘--save-model-prefix‘, type=str,                        help=‘the prefix of the model to save‘)    parser.add_argument(‘--num-epochs‘, type=int, default=10,                        help=‘the number of training epochs‘)    parser.add_argument(‘--load-epoch‘, type=int,                        help="load the model on an epoch using the model-prefix")    parser.add_argument(‘--kv-store‘, type=str, default=‘local‘,                        help=‘the kvstore type‘)    parser.add_argument(‘--lr-factor‘, type=float, default=1,                        help=‘times the lr with a factor for every lr-factor-epoch epoch‘)    parser.add_argument(‘--lr-factor-epoch‘, type=float, default=1,                        help=‘the number of epoch to factor the lr, could be .5‘)    return parser.parse_args()

if __name__ == ‘__main__‘:    args = parse_args()

    if args.network == ‘mlp‘:        data_shape = (784, )        net = get_mlp()    elif args.network == ‘lenet-stn‘:        data_shape = (1, 28, 28)        net = get_lenet(True)    else:        data_shape = (1, 28, 28)        net = get_lenet()

    # train    train_model.fit(args, net, get_iterator(data_shape))

先看Main函数,就是读配置参数,读网络结构,包括设置数据的大小,然后就是调用已有的包train_model。然后传入这之前设置的三个参数。就开始训练了。编程架构也蛮清晰的。模块化也搞的好。接着看看参数设置问题。参数导入了很多配置文件,基本上caffe中的Proto都在这个里面设置了。包括数据集地址,批大小,学习率,损失函数,等等。然后看看读网络结构,读网络结构就是在一层一层的搭积木,根据之前读入的配置文件或者自己定义一些参数。搭好积木就开始训练了。caffe的一个缺点是不够灵活,毕竟不是自己写代码,只是写配置文件,总感觉受制于人。mxnet和tensorflow就比较方便,提供api,你可以按你的方式来调用和定义网络结构。总的说来,其实是后两个框架模块化做的好,提供底层的api支持你写自己的网络。caffe要自己写网络层的话还是很费劲的
时间: 2024-10-27 12:28:06

mxnet实战系列(一)入门与跑mnist数据集的相关文章

Spark入门实战系列--1.Spark及其生态圈简介

[注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 1.简介 1.1 Spark简介 Spark是加州大学伯克利分校AMP实验室(Algorithms, Machines, and People Lab)开发通用内存并行计算框架.Spark在2013年6月进入Apache成为孵化项目,8个月后成为Apache顶级项目,速度之快足见过人之处,Spark以其先进的设计理念,迅速成为社区的热门项目,围绕着Spark推出了Spark SQL.Spark St

Spark入门实战系列--6.SparkSQL(中)--深入了解运行计划及调优

[注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 1.1  运行环境说明 1.1.1 硬软件环境 l  主机操作系统:Windows 64位,双核4线程,主频2.2G,10G内存 l  虚拟软件:VMware® Workstation 9.0.0 build-812388 l  虚拟机操作系统:CentOS6.5 64位,单核 l  虚拟机运行环境: Ø  JDK:1.7.0_55 64位 Ø  Hadoop:2.2.0(需要编译为64位) Ø 

Spark入门实战系列--8.Spark MLlib(上)--机器学习及SparkMLlib简介

[注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 1.机器学习概念 1.1 机器学习的定义 在维基百科上对机器学习提出以下几种定义: l“机器学习是一门人工智能的科学,该领域的主要研究对象是人工智能,特别是如何在经验学习中改善具体算法的性能”. l“机器学习是对能通过经验自动改进的计算机算法的研究”. l“机器学习是用数据或以往的经验,以此优化计算机程序的性能标准.” 一种经常引用的英文定义是:A computer program is said

Spark入门实战系列--2.Spark编译与部署(下)--Spark编译安装

[注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 1.编译Spark Spark可以通过SBT和Maven两种方式进行编译,再通过make-distribution.sh脚本生成部署包.SBT编译需要安装git工具,而Maven安装则需要maven工具,两种方式均需要在联网下进行,通过比较发现SBT编译速度较慢(原因有可能是1.时间不一样,SBT是白天编译,Maven是深夜进行的,获取依赖包速度不同 2.maven下载大文件是多线程进行,而SBT是

Spark入门实战系列--7.Spark Streaming(下)--实时流计算Spark Streaming实战

[注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 1.实例演示 1.1 流数据模拟器 1.1.1 流数据说明 在实例演示中模拟实际情况,需要源源不断地接入流数据,为了在演示过程中更接近真实环境将定义流数据模拟器.该模拟器主要功能:通过Socket方式监听指定的端口号,当外部程序通过该端口连接并请求数据时,模拟器将定时将指定的文件数据随机获取发送给外部程序. 1.1.2 模拟器代码 import java.io.{PrintWriter} impor

Spark入门实战系列--2.Spark编译与部署(中)--Hadoop编译安装

[注]该系列文章以及使用到安装包/測试数据 能够在<[倾情大奉送–Spark入门实战系列] (http://blog.csdn.net/yirenboy/article/details/47291765)>获取 1 编译Hadooop 1.1 搭建好开发环境 1.1.1 安装并设置maven 1.下载maven安装包.建议安装3.0以上版本号,本次安装选择的是maven3.0.5的二进制包,下载地址例如以下 http://mirror.bit.edu.cn/apache/maven/maven

Spark入门实战系列--9.Spark图计算GraphX介绍及实例

[注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 1.GraphX介绍 1.1 GraphX应用背景 Spark GraphX是一个分布式图处理框架,它是基于Spark平台提供对图计算和图挖掘简洁易用的而丰富的接口,极大的方便了对分布式图处理的需求. 众所周知·,社交网络中人与人之间有很多关系链,例如Twitter.Facebook.微博和微信等,这些都是大数据产生的地方都需要图计算,现在的图处理基本都是分布式的图处理,而并非单机处理.Spark

Spark入门实战系列--5.Hive(上)--Hive介绍及部署

[注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 1.Hive介绍 1.1 Hive介绍 Hive是一个基于Hadoop的开源数据仓库工具,用于存储和处理海量结构化数据.它是Facebook 2008年8月开源的一个数据仓库框架,提供了类似于SQL语法的HQL语句作为数据访问接口,Hive有如下优缺点: l  优点: 1.Hive 使用类SQL 查询语法, 最大限度的实现了和SQL标准的兼容,大大降低了传统数据分析人员学习的曲线: 2.使用JDBC

Spark入门实战系列--8.Spark MLlib(下)--机器学习库SparkMLlib实战

[注]该系列文章以及使用到安装包/测试数据 可以在<倾情大奉送--Spark入门实战系列>获取 1.MLlib实例 1.1 聚类实例 1.1.1 算法说明 聚类(Cluster analysis)有时也被翻译为簇类,其核心任务是:将一组目标object划分为若干个簇,每个簇之间的object尽可能相似,簇与簇之间的object尽可能相异.聚类算法是机器学习(或者说是数据挖掘更合适)中重要的一部分,除了最为简单的K-Means聚类算法外,比较常见的还有层次法(CURE.CHAMELEON等).网