读取多张MNIST图片与利用BaseEstimator基类创建分类器

读取多张MNIST图片

在读取多张MNIST图片之前,我们先来看下读取单张图片如何实现

每张数字图片大小都为28 * 28的,需要将数据reshape成28 * 28的,采用最近邻插值,如下

def plot_digit(data):
    img = data.reshape(28,28)
    plt.imshow(img,cmap=matplotlib.cm.binary,interpolation=‘nearest‘)
    plt.axis(‘off‘)
import matplotlib.pyplot as plt
import matplotlib
some_digit = X[36000]
plot_digit(some_digit)

现在来读取多张MNIST图片

需要确定每行显示多少张图片,根据照片数最多显示几行,最后一行有几个未填满,将每行进行连接起来

def plot_digits(instances,images_per_row = 10,**options):
    size = 28
    images_per_row = min(len(instances),images_per_row)
    images = [instance.reshape(size,size) for instance in instances]
    n_rows = (len(instances) - 1) // images_per_row +1
    row_images = []
    n_empty = n_rows * images_per_row - len(instances)
    images.append(np.zeros((size,size*n_empty)))
    for row in range(n_rows):
        rimages = images[row * images_per_row:(row+1) * images_per_row]
        row_images.append(np.concatenate(rimages,axis=1))
    image = np.concatenate(row_images,axis=0)
    plt.imshow(image,cmap=matplotlib.cm.binary,**options)
    plt.axis(‘off‘)
import numpy as np
import os

# to make this notebook‘s output stable across runs
np.random.seed(42)

# To plot pretty figures
%matplotlib inline
import matplotlib as mpl
import matplotlib.pyplot as plt
mpl.rc(‘axes‘, labelsize=14)
mpl.rc(‘xtick‘, labelsize=12)
mpl.rc(‘ytick‘, labelsize=12)

# Where to save the figures
PROJECT_ROOT_DIR = "."
#CHAPTER_ID = "classification"

def save_fig(fig_id, tight_layout=True):
    path = os.path.join(PROJECT_ROOT_DIR, "images", fig_id + ".png")
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(path, format=‘png‘, dpi=300)
plt.figure(figsize=(9,9))
example_images = np.r_[X[:12000:600], X[13000:30600:600], X[30600:60000:590]]
plot_digits(example_images, images_per_row=10)
save_fig("more_digits_plot")
plt.show()

显示并将结果存入磁盘

利用BaseEstimator基类创建分类器

在做非5分类器的交叉验证时,需要写一个非5的分类器

估计器(Estimator)很多时候可以直接理解成分类器,主要包括两个函数

  • fit():训练算法,设置内部参数,接受训练集和类别两个参数
  • predict():预测测试集类别,参数为测试集

大多数sklearn估计器接受和输出的数据格式均为numpy数组或类似格式

from sklearn.base import BaseEstimator
class Never5Classifier(BaseEstimator):
    def fit(self,X,y = None):
        pass
    def predict(self,X):
        return np.zeros((len(X),1),dtype = bool)
never_5_clf = Never5Classifier()
cross_val_score(never_5_clf,X_train,y_train_5,cv = 3,scoring=‘accuracy‘)

Never5Classifier分类器预测的结果都是0,而数字为5的标签应该都为1,非5的为0,这时候可以看出也有90%的可能性猜对某张图片不是5

关于评估器以及转换器、流水线(Pipline)等更多参考:https://www.jianshu.com/p/516f009c0875

原文地址:https://www.cnblogs.com/whiteBear/p/12341094.html

时间: 2024-08-02 17:26:37

读取多张MNIST图片与利用BaseEstimator基类创建分类器的相关文章

虚基类练习:动物(利用虚基类建立一个类的多重继承,包括动物(animal,属性有体长,体重和性别),陆生动物(ter_animal,属性增加了奔跑速度),水生动物(aqu_animal,属性增加了游泳速度)和两栖动物(amp_animal)。)

Description 长期的物种进化使两栖动物既能活跃在陆地上,又能游动于水中.利用虚基类建立一个类的多重继承,包括动物(animal,属性有体长,体重和性别),陆生动物(ter_animal,属性增加了奔跑速度),水生动物(aqu_animal,属性增加了游泳速度)和两栖动物(amp_animal).其中两栖动物保留了陆生动物和水生动物的属性. Input 两栖动物的体长,体重,性别,游泳速度,奔跑速度(running_speed) Output 初始化的两栖动物的体长,体重,性别,游泳速度

MATLAB读取一张RGB图片转成YUV格式

1.读入照片 控制输出的标志定义 clc;close all;clear YES = 1; NO = 0; %YES表示输出该文件,请用户配置 yuv444_out_txt = 1; yuv444_out_yuv = 0; yuv422_out_txt = 0; yuv422_out_yuv = 0; yuv420_out_txt = 0; yuv420_out_yuv = 1; filename = 'Koala.jpg'; filestr = filename(1:findstr(filen

ASP.NET MVC with Entity Framework and CSS一书翻译系列文章之第二章:利用模型类创建视图、控制器和数据库

在这一章中,我们将直接进入项目,并且为产品和分类添加一些基本的模型类.我们将在Entity Framework的代码优先模式下,利用这些模型类创建一个数据库.我们还将学习如何在代码中创建数据库上下文类.指定数据库连接字符串以及创建一个数据库.最后,我们还将添加视图和控制器来管理和显式产品和分类数据. 注意:如果你想按照本章的代码编写示例,你必须完成第一章或者直接从www.apress.com下载第一章的源代码. 2.1 添加模型类 Entity Framework的代码优先模式允许我们从模型类创

【Android】读取sdcard上的图片

Android读取sdcard上的图片是很easy的事情,以下用一个样例来说明这个问题. 首先,在sdcard上有一张已经准备好的img25.jpg 以下,须要做的是把这张图片读取到app中显示. 做到例如以下的效果: 1.首先你要在AndroidManifest.xml申请读取sdcard的权限,增加一条语句之后,AndroidManifest.xml例如以下: <?xml version="1.0" encoding="utf-8"? > <m

一张jpg图片实际加载过程内存消耗

一张jpg图片实际加载过程内存消耗,以一张1024*1024 argb8888 500k的jpg图片为例: a.读取图片文件(消耗图片大小内存,500k)     b.解析jpg数据(cgimage, 4mb) c.释放500k的图片内存    d.opengl纹理数据(4mb)    e.释放cgimage的4mb内存.      注意,这个过程不是必然的顺序执行,释放cgimage内存的实际是有系统决定的,会很快,但是不一定是立即执行.  所以内存会瞬间飙升9mb左右,然后减少5mb,稳定到

Mapnik读取PostGIS数据渲染图片

__author__ = 'Administrator' # encoding: utf-8 import sys import datetime import mapnik m = mapnik.Map(256,256,"+proj=latlong +datum=WGS84") #m.background = mapnik.Color('steelblue') # set background colour to 'steelblue'. s = mapnik.Style() r =

一张png图片 上面有多个图标,如何用CSS准确的知道其中某个图片的坐标

一张png图片 上面有多个图标,如何用CSS准确的知道其中某个图片的坐标 ,如下图 可以使用  background background:url(images/xx.png) 40px 10px no-repeat;

Python读取excel中的图片

Python读取excel中的图片文件,并转成base64 import sys import os import xlrd import zipfile import base64 class ExcelImgRead(object): def change_file_name(self, file_path, old_name, new_type = '.zip'): """ 修改指定目录下的文件类型名 :param file_path: :param old: :par

Atitit 判断判断一张图片是否包含另一张小图片

1. keyword1 2.  模板匹配是在图像中寻找目标的方法之一(切割+图像相似度计算)1 3. 匹配效果2 4. 图片相似度的算法(感知哈希算法”(Perceptual hash algorithm)2 5. 性能结果2 6. 如何提升性能3 6.1. 可以采用简化的算法.二次匹配法,先大概确定区域3 6.2. 切割图片设置一个step3 7. 参考资料3 8. ------code3 1. keyword 图像匹配 图片是否另外一张图片的一部分 如果是标准图片,模板匹配就好 2.  模板