TensorFlow自带的MNIST数据集相关情况
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
# Load the MNIST dataset shipped with the TensorFlow tutorials helper.
# The files are downloaded into ./data/ on first run (slow the first time,
# reused afterwards). one_hot=True encodes every label as a length-10
# one-hot vector instead of an integer class id.
# NOTE: tensorflow.examples.tutorials.mnist is a TensorFlow 1.x module; it was
# removed in TensorFlow 2.x (tf.keras.datasets.mnist is the replacement).
mnist = input_data.read_data_sets('data/', one_hot=True)
print(type(mnist))
print(mnist.train.num_examples)  # 55000
print(mnist.test.num_examples)  # 10000

img_train = mnist.train.images
label_train = mnist.train.labels
img_test = mnist.test.images
label_test = mnist.test.labels

# Every split is a plain NumPy array.
print(type(img_train))  # <class 'numpy.ndarray'>
print(type(label_train))  # <class 'numpy.ndarray'>
print(type(img_test))  # <class 'numpy.ndarray'>
print(type(label_test))  # <class 'numpy.ndarray'>

print(img_train.shape)  # (55000, 784): each row is a flattened 28*28 image
print(label_train.shape)  # (55000, 10)
print(img_test.shape)  # (10000, 784)
# One-hot label rows make it easy to take the most probable class via argmax.
print(label_test.shape)  # (10000, 10)
# Show a few randomly chosen training images together with their labels.
num_sample = 5
# Draw distinct row indices so the same image is never shown twice.
rand_idx = np.random.choice(img_train.shape[0], size=num_sample, replace=False)
for i in rand_idx:
    # Each row is a flattened 784-pixel image; restore the 28x28 grid.
    cur_img = np.reshape(img_train[i, :], (28, 28))
    # Labels are one-hot encoded, so the position of the maximum is the digit.
    cur_label = np.argmax(label_train[i, :])
    plt.matshow(cur_img, cmap=plt.get_cmap('gray'))
    print(str(i) + "训练数据的标签是" + str(cur_label))
# Uncomment to display the figures created by plt.matshow above.
# plt.show()
# Pull one mini-batch from the training split and report its types and shapes.
batch_size = 100
batch_x, batch_y = mnist.train.next_batch(batch_size)
# Both arrays print as <class 'numpy.ndarray'>, one per line.
print(type(batch_x), type(batch_y), sep="\n")
# Expected: (100, 784) for the images and (100, 10) for the one-hot labels.
print(batch_x.shape, batch_y.shape, sep="\n")
原文地址:https://blog.51cto.com/5669384/2415956
时间: 2024-11-05 14:49:17