python读取MNIST image数据

Lecun Mnist数据集下载

import numpy as np
import struct

def loadImageSet(which=0):
    print "load image set"
    binfile=None
    if which==0:
        binfile = open("..//dataset//train-images-idx3-ubyte", ‘rb‘)
    else:
        binfile=  open("..//dataset//t10k-images-idx3-ubyte", ‘rb‘)
    buffers = binfile.read()

    head = struct.unpack_from(‘>IIII‘ , buffers ,0)
    print "head,",head

    offset=struct.calcsize(‘>IIII‘)
    imgNum=head[1]
    width=head[2]
    height=head[3]
    #[60000]*28*28
    bits=imgNum*width*height
    bitsString=‘>‘+str(bits)+‘B‘ #like ‘>47040000B‘

    imgs=struct.unpack_from(bitsString,buffers,offset)

    binfile.close()
    imgs=np.reshape(imgs,[imgNum,width,height])
    print "load imgs finished"
    return imgs

def loadLabelSet(which=0):
    print "load label set"
    binfile=None
    if which==0:
        binfile = open("..//dataset//train-labels-idx1-ubyte", ‘rb‘)
    else:
        binfile=  open("..//dataset//t10k-labels-idx1-ubyte", ‘rb‘)
    buffers = binfile.read()

    head = struct.unpack_from(‘>II‘ , buffers ,0)
    print "head,",head
    imgNum=head[1]

    offset = struct.calcsize(‘>II‘)
    numString=‘>‘+str(imgNum)+"B"
    labels= struct.unpack_from(numString , buffers , offset)
    binfile.close()
    labels=np.reshape(labels,[imgNum,1])

    #print labels
    print ‘load label finished‘
    return labels

if __name__=="__main__":
    imgs=loadImageSet()
    #import PlotUtil as pu
    #pu.showImgMatrix(imgs[0])
    loadLabelSet()

及方便训练的reader

import numpy as np
import struct
import gzip
import cPickle

class MnistReader():

    def __init__(self,mnist_path,data_dim=1,one_hot=True):
        ‘‘‘
        mnist_path: the path of mnist.pkl.gz
        data_dim=1 [N,784]
        data_dim=3 [N,28,28,1]
        one_hot: one hot encoding(like: [0,1,0,0,0,0,0,0,0,0]) if true
        ‘‘‘
        self.mnist_path=mnist_path
        self.data_dim=data_dim
        self.one_hot=one_hot
        self.load_minist(mnist_path)

        self.train_datalabel=zip(self.train_x,self.train_y)
        self.valid_datalabel=zip(self.valid_x,self.valid_y)

        self.batch_offset_train=0

    def next_batch_train(self,batch_size):
        ‘‘‘
        return list of images with shape [N,784] or [N,28,28,1] dependents on self.data_dim
               and list of labels with shape [N] or [N,10] dependents on self.one_hot
        ‘‘‘
        if self.batch_offset_train<len(self.train_datalabel)//batch_size:
            imgs=list();labels=list()
            for d,l in self.train_datalabel[self.batch_offset_train:self.batch_offset_train+batch_size]:
                if self.data_dim==3:
                    d=np.reshape(d, [28,28,1])
                imgs.append(d)
                if self.one_hot:
                    a=np.zeros(10)
                    a[l]=1
                    labels.append(l)
                else:
                    labels.append(l)
            self.batch_offset_train+=1
            return imgs,labels
        else:
            self.batch_offset_train=0
            np.random.shuffle(self.train_datalabel)
            return self.next_batch_train(batch_size)

    def next_batch_val(self,batch_size):
        ‘‘‘
        return list of images with shape [N,784] or [N,28,28,1] dependents on self.data_dim
               and list of labels with shape [N,1] or [N,10] dependents on self.one_hot
        ‘‘‘
        np.random.shuffle(self.valid_datalabel)
        imgs=list();labels=list()
        for d,l in self.train_datalabel[0:batch_size]:
            if self.data_dim==3:
                d=np.reshape(d, [28,28,1])
            imgs.append(d)
            if self.one_hot:
                a=np.zeros(10)
                a[l]=1
                labels.append(l)
            else:
                labels.append(l)
        return imgs,labels

    def load_minist(self,dataset):
        print "load dataset"
        f = gzip.open(dataset, ‘rb‘)
        train_set, valid_set, test_set = cPickle.load(f)
        f.close()
        self.train_x,self.train_y=train_set
        self.valid_x,self.valid_y=valid_set
        self.test_x , self.test_y=test_set
        print "train image,label shape:",self.train_x.shape,self.train_y.shape
        print "valid image,label shape:",self.valid_x.shape,self.valid_y.shape
        print "test  image,label shape:",self.test_x.shape,self.test_y.shape
        print "load dataset end"

if __name__=="__main__":
    mnist=MnistReader(‘../dataset/mnist.pkl.gz‘,data_dim=3)
    data,label=mnist.next_batch_train(batch_size=1)
    print data
    print label

第三种加载方式需要 gzip和struct

import gzip, struct

def _read(image,label):
    minist_dir = ‘your_dir/‘
    with gzip.open(minist_dir+label) as flbl:
        magic, num = struct.unpack(">II", flbl.read(8))
        label = np.fromstring(flbl.read(), dtype=np.int8)
    with gzip.open(minist_dir+image, ‘rb‘) as fimg:
        magic, num, rows, cols = struct.unpack(">IIII", fimg.read(16))
        image = np.fromstring(fimg.read(), dtype=np.uint8).reshape(len(label), rows, cols)
    return image,label

def get_data():
    train_img,train_label = _read(
            ‘train-images-idx3-ubyte.gz‘,
            ‘train-labels-idx1-ubyte.gz‘)
    test_img,test_label = _read(
            ‘t10k-images-idx3-ubyte.gz‘,
            ‘t10k-labels-idx1-ubyte.gz‘)
    return [train_img,train_label,test_img,test_label]

原文地址:https://www.cnblogs.com/judejie/p/9143974.html

时间: 2024-08-06 22:17:39

python读取MNIST image数据的相关文章

python读取mnist

python读取mnist 其实就是python怎么读取binnary file mnist的结构如下,选取train-images TRAINING SET IMAGE FILE (train-images-idx3-ubyte): [offset] [type]          [value]          [description] 0000     32 bit integer  0x00000803(2051) magic number 0004     32 bit integ

Python读取MNIST数据集

MNIST数据集获取 MNIST数据集是入门机器学习/模式识别的最经典数据集之一.最早于1998年Yan Lecun在论文: Gradient-based learning applied to document recognition. 中提出.经典的LeNet-5 CNN网络也是在该论文中提出的. 数据集包含了0-9共10类手写数字图片,每张图片都做了尺寸归一化,都是28x28大小的灰度图.每张图片中像素值大小在0-255之间,其中0是黑色背景,255是白色前景.如下图所示: MNIST共包

python读取EXCLE文件数据

python读取EXCEL,利用 Google 搜索 Python Excel,点击第一条结果http://www.python-excel.org/ ,能够跨平台处理 Excel. 按照文档一步步去做,要安装 三个包: xlrd(用于读Excel): xlwt(用于写Excel): xlutils(处理Excel的工具箱) 1 from xlrd import open_workbook 2 import re 3 4 #创建一个用于读取sheet的生成器,依次生成每行数据,row_count

Python读取SQLite文件数据

近日在做项目时,意外听说有一种SQLite的数据库,相比自己之前使用的SQL Service甚是轻便,在对数据完整性.并发性要求不高的场景下可以尝试! 1.SQLite简介: SQLite是一个进程内的库,实现了自给自足的.无服务器的.零配置的.事务性的 SQL 数据库引擎.它的设计目标是嵌入式的,而且目前已经在很多嵌入式产品中使用了它(如安卓系统),它占用资源非常的低,在嵌入式设备中,可能只需要几百K的内存就够了.它能够支持Windows/Linux/Unix等等主流的操作系统,同时能够跟很多

python读取数据库表数据并写入excel

一个简单的使用python读取mysql数据并写入excel中实例 1.python连接mysql数据库 conn = pymysql.connect(user='root',host='127.0.0.1',port=3306,passwd='root',db='python',charset='utf8') #连接数据库 cur = conn.cursor() 2.读取mysql数据库中表数据 1 sql = 'select * from %s;' %table_name #需要写入exce

python读取txt天气数据并使用matplotlib模块绘图

天气数据可以从网上下载,这个例子的数据是从http://data.cma.cn/下载而来的. 下载的数据装在txt文件中. 里面包含了12年开始北京的月最低和最高温度. 读取数据: 1 with open('S201812261702093585500.txt') as file_object: 2 lines=file_object.readlines() 将txt中的数据逐行存到列表lines里 lines的每一个元素对应于txt中的一行.然后将每个元素中的不同信息提取出来: 1 file1

python读取mnist label数据库

<br>[offset] [type] [value] [description] 0000 32 bit integer 0x00000803(2051) magic number 0004 32 bit integer 60000 number of items 0008 unsigned byte ?? label 0009 unsigned byte ?? label ........ xxxx unsigned byte ?? label Mnist label数据结构如上. 完整代

ean13码的生成,python读取csv中数据并处理返回并写入到另一个csv文件中

# -*- coding: utf-8 -*- import math import re import csv import repr def ean_checksum(eancode): """returns the checksum of an ean string of length 13, returns -1 if the string has the wrong length""" if len(eancode) != 13: re

mnist的格式说明,以及在python3.x和python 2.x读取mnist数据集的不同

#!/usr/bin/env python # -*- coding: UTF-8 -*- import struct # from bp import * from datetime import datetime # 数据加载器基类 class Loader(object): def __init__(self, path, count): ''' 初始化加载器 path: 数据文件路径 count: 文件中的样本个数 ''' self.path = path self.count = co